From cc53791322cd9152f6b6a9c7cef4cf9748e68041 Mon Sep 17 00:00:00 2001
From: Lindsey Gray
Date: Fri, 1 Mar 2024 09:04:34 -0600
Subject: [PATCH] get all tests working again

---
 src/dask_histogram/boost.py | 58 ++++++++++++++++++++++++++++---------
 src/dask_histogram/core.py  |  2 +-
 2 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py
index 442ee50..25e8ded 100644
--- a/src/dask_histogram/boost.py
+++ b/src/dask_histogram/boost.py
@@ -19,11 +19,13 @@
 from tlz import first
 
 from dask_histogram.bins import normalize_bins_range
-from dask_histogram.core import (  # partitioned_factory,
+from dask_histogram.core import (
     AggHistogram,
     _get_optimization_function,
     _partitioned_histogram_multifill,
     _reduction,
+    factory,
+    is_dask_awkward_like,
 )
 
 if TYPE_CHECKING:
@@ -213,24 +215,52 @@ def _build_taskgraph(self):
         weights = []
         samples = []
 
-        for afill in self._staged:
-            data_list.append(afill["args"])
-            weights.append(afill["kwargs"]["weight"])
-            samples.append(afill["kwargs"]["sample"])
+        dask_data = tuple(
+            datum
+            for datum in (
+                self._staged[0]["args"] + tuple(self._staged[0]["kwargs"].values())
+            )
+            if is_dask_collection(datum)
+        )
 
-        if all(weight is None for weight in weights):
-            weights = None
+        if is_dask_awkward_like(dask_data[0]):
 
-        if not all(sample is None for sample in samples):
-            samples = None
+            for afill in self._staged:
+                data_list.append(afill["args"])
+                weights.append(afill["kwargs"]["weight"])
+                samples.append(afill["kwargs"]["sample"])
 
-        split_every = dask.config.get("histogram.aggregation.split_every", 8)
+            if all(weight is None for weight in weights):
+                weights = None
 
-        fills = _partitioned_histogram_multifill(
-            data_list, self._histref, weights, samples
-        )
+            if not all(sample is None for sample in samples):
+                samples = None
+
+            split_every = dask.config.get("histogram.aggregation.split_every", 8)
+
+            fills = _partitioned_histogram_multifill(
+                data_list, self._histref, weights, samples
+            )
 
-        output_hist = _reduction(fills, split_every)
+            output_hist = _reduction(fills, split_every)
+        else:
+
+            first_fill = self._staged.pop()
+
+            output_hist = factory(
+                *first_fill["args"],
+                histref=self._histref,
+                weights=first_fill["kwargs"]["weight"],
+                sample=first_fill["kwargs"]["sample"],
+            )
+
+            for afill in self._staged:
+                output_hist += factory(
+                    *afill["args"],
+                    histref=self._histref,
+                    weights=afill["kwargs"]["weight"],
+                    sample=afill["kwargs"]["sample"],
+                )
 
         self._staged = None
         self._staged_result = output_hist
diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py
index 04a796d..a2014f9 100644
--- a/src/dask_histogram/core.py
+++ b/src/dask_histogram/core.py
@@ -426,7 +426,7 @@ def _blocked_multi_dak(
         sample = None if samples is None else samples[idata]
 
         if backend != "typetracer":
-            thehist.fill(*tuple(data), weight=weight, sample=sample)
+            thehist.fill(*data, weight=weight, sample=sample)
         else:
             for datum in data:
                 if isinstance(datum, ak.highlevel.Array):