From 235e4d18ce0c9b48f3c50a62405d96f1b69aced5 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Sun, 28 Jan 2024 12:01:43 -0600 Subject: [PATCH 1/2] use a tree reduce for staged fills instead of pairwise adds --- src/dask_histogram/boost.py | 89 ++++++++++++++++++++++++++++++++----- src/dask_histogram/core.py | 53 ++++++++++++---------- tests/test_boost.py | 36 ++++++++------- tests/test_core.py | 32 ++++++------- 4 files changed, 145 insertions(+), 65 deletions(-) diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py index e7f8cbc..743aba8 100644 --- a/src/dask_histogram/boost.py +++ b/src/dask_histogram/boost.py @@ -8,7 +8,9 @@ import boost_histogram as bh import boost_histogram.axis as axis import boost_histogram.storage as storage +import dask import dask.array as da +from dask.bag.core import empty_safe_aggregate, partition_all from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize from dask.context import globalmethod from dask.delayed import Delayed, delayed @@ -34,6 +36,60 @@ __all__ = ("Histogram", "histogram", "histogram2d", "histogramdd") +def _hist_safe_sum(items): + safe_items = [item for item in items if not isinstance(item, tuple)] + return sum(safe_items) + + +def _build_staged_tree_reduce( + stages: list[AggHistogram], split_every: int | bool +) -> HighLevelGraph: + if not split_every: + split_every = len(stages) + + reducer = sum + + token = tokenize(stages, reducer, split_every) + + k = len(stages) + b = "" + fmt = f"staged-fill-aggregate-{token}" + depth = 0 + + dsk = {} + + if k > 1: + while k > split_every: + c = fmt + str(depth) + for i, inds in enumerate(partition_all(split_every, range(k))): + dsk[(c, i)] = ( + empty_safe_aggregate, + reducer, + [ + (stages[j].name if depth == 0 else b, 0 if depth == 0 else j) + for j in inds + ], + False, + ) + + k = i + 1 + b = c + depth += 1 + + dsk[(fmt, 0)] = ( + empty_safe_aggregate, + reducer, + [ + (stages[j].name if depth == 0 else b, 0 if depth == 0 else j) + for j in range(k) + ], + True, + ) + return fmt, HighLevelGraph.from_collections(fmt, dsk, dependencies=stages) + + return stages[0].name, stages[0].dask + + class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram): """Histogram object capable of lazy computation. @@ -46,6 +102,9 @@ class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram): type is :py:class:`boost_histogram.storage.Double`. metadata : Any Data that is passed along if a new histogram is created. + split_every : int | bool | None, default None + Width of aggregation layers for staged fills. + If False, all staged fills are added in one layer (memory intensive!). See Also -------- @@ -81,10 +140,11 @@ def __init__( *axes: bh.axis.Axis, storage: bh.storage.Storage = bh.storage.Double(), metadata: Any = None, + split_every: int | None = None, ) -> None: """Construct a Histogram object.""" super().__init__(*axes, storage=storage, metadata=metadata) - self._staged: AggHistogram | None = None + self._staged: list[AggHistogram] | None = None self._dask_name: str | None = ( f"empty-histogram-{tokenize(*axes, storage, metadata)}" ) @@ -92,12 +152,15 @@ def __init__( {self._dask_name: {(self._dask_name, 0): (lambda: self,)}}, {}, ) + self._split_every = split_every + if self._split_every is None: + self._split_every = dask.config.get("histogram.aggregation.split_every", 8) @property def _histref(self): return ( tuple(self.axes), - self.storage_type(), + self.storage_type, self.metadata, ) @@ -107,11 +170,16 @@ def __iadd__(self, other): elif not self.staged_fills() and other.staged_fills(): self._staged = other._staged if self.staged_fills(): - self._dask = self._staged.__dask_graph__() - self._dask_name = self._staged.name + new_name, new_graph = _build_staged_tree_reduce( + self._staged, self._split_every + ) + self._dask = new_graph + self._dask_name = new_name return self def __add__(self, other): + print(self) + print(other) return self.__iadd__(other) def __radd__(self, other): @@ -259,11 +327,12 @@ def fill( # type: ignore new_fill = factory(*args, histref=self._histref, weights=weight, sample=sample) if self._staged is None: - self._staged = new_fill + self._staged = [new_fill] else: - self._staged += new_fill - self._dask = self._staged.__dask_graph__() - self._dask_name = self._staged.name + self._staged += [new_fill] + new_name, new_graph = _build_staged_tree_reduce(self._staged, self._split_every) + self._dask = new_graph + self._dask_name = new_name return self @@ -321,7 +390,7 @@ def to_delayed(self) -> Delayed: """ if self._staged is not None: - return self._staged.to_delayed() + return sum(self._staged).to_delayed() return delayed(bh.Histogram(self)) def __repr__(self) -> str: @@ -389,7 +458,7 @@ def to_dask_array(self, flow: bool = False, dd: bool = True) -> Any: """ if self._staged is not None: - return self._staged.to_dask_array(flow=flow, dd=dd) + return sum(self._staged).to_dask_array(flow=flow, dd=dd) else: counts, edges = self.to_numpy(flow=flow, dd=True, view=False) counts = da.from_array(counts) diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py index afa2140..20b1bee 100644 --- a/src/dask_histogram/core.py +++ b/src/dask_histogram/core.py @@ -63,7 +63,7 @@ def _blocked_sa( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data) @@ -83,7 +83,7 @@ def _blocked_sa_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data, sample=sample) @@ -103,7 +103,7 @@ def _blocked_sa_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data, weight=weights) @@ -124,7 +124,7 @@ def _blocked_sa_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data, weight=weights, sample=sample) @@ -142,7 +142,7 @@ def _blocked_ma( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*data) @@ -157,7 +157,7 @@ def _blocked_ma_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*data, sample=sample) @@ -172,7 +172,7 @@ def _blocked_ma_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*data, weight=weights) @@ -188,7 +188,7 @@ def _blocked_ma_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*data, weight=weights, sample=sample) @@ -201,7 +201,7 @@ def _blocked_df( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), weight=None) @@ -215,7 +215,7 @@ def _blocked_df_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), sample=sample) @@ -230,7 +230,7 @@ def _blocked_df_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), weight=weights) @@ -246,7 +246,7 @@ def _blocked_df_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), weight=weights, sample=sample) @@ -279,7 +279,7 @@ def _blocked_dak( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(thedata, weight=theweights, sample=thesample) @@ -300,7 +300,7 @@ def _blocked_dak_ma( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*tuple(thedata)) @@ -326,7 +326,7 @@ def _blocked_dak_ma_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*tuple(thedata), weight=theweights) @@ -352,7 +352,7 @@ def _blocked_dak_ma_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*tuple(thedata), sample=thesample) @@ -383,7 +383,7 @@ def _blocked_dak_ma_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) ) return thehist.fill(*tuple(thedata), weight=theweights, sample=thesample) @@ -507,11 +507,15 @@ def histref(self): @property def _storage_type(self) -> type[bh.storage.Storage]: """Storage type of the histogram.""" + if isinstance(self.histref, tuple): + return self.histref[1] return self.histref.storage_type @property def ndim(self) -> int: """Total number of dimensions.""" + if isinstance(self.histref, tuple): + return len(self.histref[0]) return self.histref.ndim @property @@ -738,6 +742,11 @@ def to_delayed(self, optimize_graph: bool = True) -> list[Delayed]: return [Delayed(k, graph, layer=layer) for k in keys] +def _hist_safe_sum(items): + safe_items = [item for item in items if not isinstance(item, tuple)] + return sum(safe_items) + + def _reduction( ph: PartitionedHistogram, split_every: int | None = None, @@ -754,15 +763,11 @@ def _reduction( name_comb = f"{label}-combine-{token}" name_agg = f"{label}-agg-{token}" - def hist_safe_sum(items): - safe_items = [item for item in items if not isinstance(item, tuple)] - return sum(safe_items) - mdftr = MockableDataFrameTreeReduction( name=name_agg, name_input=ph.name, npartitions_input=ph.npartitions, - concat_func=hist_safe_sum, + concat_func=_hist_safe_sum, tree_node_func=lambda x: x, finalize_func=lambda x: x, split_every=split_every, @@ -992,7 +997,9 @@ def to_dask_array(agghist: AggHistogram, flow: bool = False, dd: bool = False) - thehist = agghist.histref if isinstance(thehist, tuple): thehist = bh.Histogram( - *agghist.histref[0], storage=agghist.histref[1], metadata=agghist.histref[2] + *agghist.histref[0], + storage=agghist.histref[1](), + metadata=agghist.histref[2], ) zeros = (0,) * thehist.ndim dsk = {(name, *zeros): (lambda x, f: x.to_numpy(flow=f)[0], agghist.key, flow)} diff --git a/tests/test_boost.py b/tests/test_boost.py index a7770bf..fdd9cef 100644 --- a/tests/test_boost.py +++ b/tests/test_boost.py @@ -127,26 +127,28 @@ def test_obj_5D_strcat_intcat_rectangular(use_weights): dhb.axis.Regular(9, -3.2, 3.2), storage=storage, ) - h.fill("testcat1", 1, *(x.T), weight=weights) - h.fill("testcat2", 2, *(x.T), weight=weights) + for i in range(25): + h.fill(f"testcat{i+1}", i + 1, *(x.T), weight=weights) h = h.compute() control = bh.Histogram(*h.axes, storage=h.storage_type()) if use_weights: - control.fill("testcat1", 1, *(x.compute().T), weight=weights.compute()) - control.fill("testcat2", 2, *(x.compute().T), weight=weights.compute()) + for i in range(25): + control.fill( + f"testcat{i+1}", i + 1, *(x.compute().T), weight=weights.compute() + ) else: - control.fill("testcat1", 1, *(x.compute().T)) - control.fill("testcat2", 2, *(x.compute().T)) + for i in range(25): + control.fill(f"testcat{i+1}", i + 1, *(x.compute().T)) assert np.allclose(h.counts(), control.counts()) if use_weights: assert np.allclose(h.variances(), control.variances()) - assert len(h.axes[0]) == 2 and len(control.axes[0]) == 2 + assert len(h.axes[0]) == 25 and len(control.axes[0]) == 25 assert all(cx == hx for cx, hx in zip(control.axes[0], h.axes[0])) - assert len(h.axes[1]) == 2 and len(control.axes[1]) == 2 + assert len(h.axes[1]) == 25 and len(control.axes[1]) == 25 assert all(cx == hx for cx, hx in zip(control.axes[1], h.axes[1])) @@ -174,27 +176,29 @@ def test_obj_5D_strcat_intcat_rectangular_dak(use_weights): dhb.axis.Regular(9, -3.2, 3.2), storage=storage, ) - h.fill("testcat1", 1, x, y, z, weight=weights) - h.fill("testcat2", 2, x, y, z, weight=weights) + for i in range(25): + h.fill(f"testcat{i+1}", i + 1, x, y, z, weight=weights) h = h.compute() control = bh.Histogram(*h.axes, storage=h.storage_type()) x_c, y_c, z_c = x.compute(), y.compute(), z.compute() if use_weights: - control.fill("testcat1", 1, x_c, y_c, z_c, weight=weights.compute()) - control.fill("testcat2", 2, x_c, y_c, z_c, weight=weights.compute()) + for i in range(25): + control.fill( + f"testcat{i+1}", i + 1, x_c, y_c, z_c, weight=weights.compute() + ) else: - control.fill("testcat1", 1, x_c, y_c, z_c) - control.fill("testcat2", 2, x_c, y_c, z_c) + for i in range(25): + control.fill(f"testcat{i+1}", i + 1, x_c, y_c, z_c) assert np.allclose(h.counts(), control.counts()) if use_weights: assert np.allclose(h.variances(), control.variances()) - assert len(h.axes[0]) == 2 and len(control.axes[0]) == 2 + assert len(h.axes[0]) == 25 and len(control.axes[0]) == 25 assert all(cx == hx for cx, hx in zip(control.axes[0], h.axes[0])) - assert len(h.axes[1]) == 2 and len(control.axes[1]) == 2 + assert len(h.axes[1]) == 25 and len(control.axes[1]) == 25 assert all(cx == hx for cx, hx in zip(control.axes[1], h.axes[1])) diff --git a/tests/test_core.py b/tests/test_core.py index e76061c..60459af 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,13 +12,13 @@ def _gen_storage(weights, sample): if weights is not None and sample is not None: - store = bh.storage.WeightedMean() + store = bh.storage.WeightedMean elif weights is None and sample is not None: - store = bh.storage.Mean() + store = bh.storage.Mean elif weights is not None and sample is None: - store = bh.storage.Weight() + store = bh.storage.Weight else: - store = bh.storage.Double() + store = bh.storage.Double return store @@ -31,7 +31,7 @@ def test_1d_array(weights, sample): sample = da.random.uniform(2, 8, size=(2000,), chunks=(250,)) store = _gen_storage(weights, sample) histref = ((bh.axis.Regular(10, -3, 3),), store, None) - h = bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + h = bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) x = da.random.standard_normal(size=(2000,), chunks=(250,)) dh = dhc.factory(x, histref=histref, weights=weights, split_every=4, sample=sample) h.fill( @@ -59,7 +59,7 @@ def test_array_input(weights, shape, sample): sample = da.random.uniform(3, 9, size=(2000,), chunks=(200,)) store = _gen_storage(weights, sample) histref = (axes, store, None) - h = bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + h = bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) dh = dhc.factory(x, histref=histref, weights=weights, split_every=4, sample=sample) h.fill( *xc, @@ -76,12 +76,12 @@ def test_multi_array(weights): bh.axis.Regular(10, -3, 3), bh.axis.Regular(10, -3, 3), ), - bh.storage.Weight(), + bh.storage.Weight, None, ) h = bh.Histogram( *histref[0], - storage=histref[1], + storage=histref[1](), metadata=histref[2], ) if weights is not None: @@ -105,12 +105,12 @@ def test_nd_array(weights): bh.axis.Regular(10, 0, 1), bh.axis.Regular(10, 0, 1), ), - bh.storage.Weight(), + bh.storage.Weight, None, ) h = bh.Histogram( *histref[0], - storage=histref[1], + storage=histref[1](), metadata=histref[2], ) if weights is not None: @@ -134,10 +134,10 @@ def test_df_input(weights): bh.axis.Regular(12, 0, 1), bh.axis.Regular(12, 0, 1), ), - bh.storage.Weight(), + bh.storage.Weight, None, ) - h = bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + h = bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) df = dds.timeseries(freq="600s", partition_freq="2d") dfc = df.compute() if weights is not None: @@ -166,7 +166,7 @@ def test_to_dask_array(weights, shape): ) h = bh.Histogram(*axes, storage=bh.storage.Weight()) dh = dhc.factory( - x, histref=(axes, bh.storage.Weight(), None), weights=weights, split_every=4 + x, histref=(axes, bh.storage.Weight, None), weights=weights, split_every=4 ) h.fill(*xc, weight=weights.compute() if weights is not None else None) c, _ = dh.to_dask_array(flow=False, dd=True) @@ -181,7 +181,7 @@ def gen_hist_1D( ) -> dhc.AggHistogram: histref = ( (bh.axis.Regular(bins, range[0], range[1]),), - bh.storage.Weight(), + bh.storage.Weight, None, ) x = da.random.standard_normal(size=size, chunks=chunks) @@ -319,12 +319,12 @@ def test_agghist_to_delayed(weights): bh.axis.Regular(10, 0, 1), bh.axis.Regular(10, 0, 1), ), - bh.storage.Weight(), + bh.storage.Weight, None, ) h = bh.Histogram( *histref[0], - storage=histref[1], + storage=histref[1](), metadata=histref[2], ) if weights is not None: From 7740f00a978407324a2e61df60681be3bec120cd Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 29 Jan 2024 16:31:55 -0500 Subject: [PATCH 2/2] address comments --- src/dask_histogram/boost.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py index 743aba8..c828326 100644 --- a/src/dask_histogram/boost.py +++ b/src/dask_histogram/boost.py @@ -36,11 +36,6 @@ __all__ = ("Histogram", "histogram", "histogram2d", "histogramdd") -def _hist_safe_sum(items): - safe_items = [item for item in items if not isinstance(item, tuple)] - return sum(safe_items) - - def _build_staged_tree_reduce( stages: list[AggHistogram], split_every: int | bool ) -> HighLevelGraph: @@ -178,8 +173,6 @@ def __iadd__(self, other): return self def __add__(self, other): - print(self) - print(other) return self.__iadd__(other) def __radd__(self, other): @@ -390,7 +383,7 @@ def to_delayed(self) -> Delayed: """ if self._staged is not None: - return sum(self._staged).to_delayed() + return sum(self._staged[1:], start=self._staged[0]).to_delayed() return delayed(bh.Histogram(self)) def __repr__(self) -> str: