diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py
index 98a7825..df36274 100644
--- a/src/dask_histogram/boost.py
+++ b/src/dask_histogram/boost.py
@@ -19,11 +19,11 @@
 from tlz import first
 
 from dask_histogram.bins import normalize_bins_range
-from dask_histogram.core import (
+from dask_histogram.core import (  # partitioned_factory,
     AggHistogram,
     _get_optimization_function,
+    _partitioned_histogram_multifill,
     hist_safe_sum,
-    partitioned_factory,
 )
 from dask_histogram.layers import MockableMultiSourceTreeReduction
 
@@ -196,24 +196,50 @@ def dask(self) -> HighLevelGraph:
         return self._dask
 
     def _build_taskgraph(self):
-        first_args = self._staged.pop()
-        first_hist = partitioned_factory(
-            *first_args["args"], histref=self._histref, **first_args["kwargs"]
-        )
-        fills = [first_hist]
-        for filling_info in self._staged:
-            fills.append(
-                partitioned_factory(
-                    *filling_info["args"],
-                    histref=self._histref,
-                    **filling_info["kwargs"],
-                )
-            )
-
-        label = "histreduce"
+        # first_args = self._staged.pop()
+        # first_hist = partitioned_factory(
+        #     *first_args["args"], histref=self._histref, **first_args["kwargs"]
+        # )
+        # fills = [first_hist]
+        # for filling_info in self._staged:
+        #     fills.append(
+        #         partitioned_factory(
+        #             *filling_info["args"],
+        #             histref=self._histref,
+        #             **filling_info["kwargs"],
+        #         )
+        #     )
+
+        data_list = []
+        weights = []
+        samples = []
+
+        for afill in self._staged:
+            data_list.append(afill["args"])
+            weights.append(afill["kwargs"]["weight"])
+            samples.append(afill["kwargs"]["sample"])
+
+        if all(weight is None for weight in weights):
+            weights = None
+
+        if not all(sample is None for sample in samples):
+            samples = None
 
         split_every = dask.config.get("histogram.aggregation.split_every", 8)
 
+        # fills = [_partitioned_histogram_multifill(
+        #     tuple(data_list[i:i+split_every]),
+        #     self._histref,
+        #     tuple(weights[i:i+split_every]),
+        #     tuple(samples[i:i+split_every]),
+        # ) for i in range(0,len(self._staged),split_every)]
+
+        fills = [
+            _partitioned_histogram_multifill(data_list, self._histref, weights, samples)
+        ]
+
+        label = "histreduce"
+
         token = tokenize(*fills, hist_safe_sum, split_every)
 
         name_comb = f"{label}-combine-{token}"
@@ -312,7 +338,7 @@ def fill(  # type: ignore
         raise ValueError(f"Cannot interpret input data: {args}")
 
         # new_fill = partitioned_factory(*args, histref=self._histref, weights=weight, sample=sample)
-        new_fill = {"args": args, "kwargs": {"weights": weight, "sample": sample}}
+        new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}}
         if self._staged is None:
             self._staged = [new_fill]
         else:
diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py
index b82f063..6b0f2eb 100644
--- a/src/dask_histogram/core.py
+++ b/src/dask_histogram/core.py
@@ -337,7 +337,11 @@ def _blocked_dak_ma_w(
         if not isinstance(histref, tuple)
         else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
     )
-    return thehist.fill(*tuple(thedata), weight=theweights)
+
+    if ak.backend(*data) != "typetracer":
+        thehist.fill(*tuple(thedata), weight=theweights)
+
+    return thehist
 
 
 def _blocked_dak_ma_s(
@@ -401,6 +405,51 @@ def _blocked_dak_ma_w_s(
     return thehist.fill(*tuple(thedata), weight=theweights, sample=thesample)
 
 
+def _blocked_multi_dak(
+    data_list: tuple[tuple[Any]],
+    weights: tuple[Any] | None,
+    samples: tuple[Any] | None,
+    histref: tuple | bh.Histogram | None = None,
+) -> bh.Histogram:
+    import awkward as ak
+
+    thehist = (
+        clone(histref)
+        if not isinstance(histref, tuple)
+        else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
+    )
+
+    backend = ak.backend(*data_list[0])
+
+    for idata, data in enumerate(data_list):
+        weight = None if weights is None else weights[idata]
+        sample = None if samples is None else samples[idata]
+
+        thedata = [
+            (
+                ak.typetracer.length_zero_if_typetracer(datum)
+                if isinstance(datum, ak.Array)
+                else datum
+            )
+            for datum in data
+        ]
+        theweight = (
+            ak.typetracer.length_zero_if_typetracer(weight)
+            if isinstance(weight, ak.Array)
+            else weight
+        )
+        thesample = (
+            ak.typetracer.length_zero_if_typetracer(sample)
+            if isinstance(sample, ak.Array)
+            else sample
+        )
+
+        if backend != "typetracer":
+            thehist.fill(*tuple(thedata), weight=theweight, sample=thesample)
+
+    return thehist
+
+
 def optimize(
     dsk: Mapping,
     keys: Hashable | list[Hashable] | set[Hashable],
@@ -874,6 +923,35 @@ def _partitionwise(func, layer_name, *args, **kwargs):
     )
 
 
+class PackedMultifill:
+    def __init__(self, repacker):
+        self.repacker = repacker
+
+    def __call__(self, *args):
+        return _blocked_multi_dak(*self.repacker(args))
+
+
+def _partitioned_histogram_multifill(
+    data: tuple[DaskCollection | tuple],
+    histref: bh.Histogram | tuple,
+    weights: tuple[DaskCollection] | None = None,
+    samples: tuple[DaskCollection] | None = None,
+):
+    name = f"hist-on-block-{tokenize(data, histref, weights, samples)}"
+
+    from dask.base import unpack_collections
+    from dask_awkward.lib.core import partitionwise_layer as dak_pwl
+
+    flattened_deps, repacker = unpack_collections(data, weights, samples, histref)
+
+    graph = dak_pwl(PackedMultifill(repacker), name, *flattened_deps)
+
+    hlg = HighLevelGraph.from_collections(name, graph, dependencies=flattened_deps)
+    return PartitionedHistogram(
+        hlg, name, flattened_deps[0].npartitions, histref=histref
+    )
+
+
 def _partitioned_histogram(
     *data: DaskCollection,
     histref: bh.Histogram | tuple,