Skip to content

Commit

Permalink
get all tests working again
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Mar 1, 2024
1 parent 782de36 commit cc53791
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
58 changes: 44 additions & 14 deletions src/dask_histogram/boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
from tlz import first

from dask_histogram.bins import normalize_bins_range
from dask_histogram.core import ( # partitioned_factory,
from dask_histogram.core import (
AggHistogram,
_get_optimization_function,
_partitioned_histogram_multifill,
_reduction,
factory,
is_dask_awkward_like,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -213,24 +215,52 @@ def _build_taskgraph(self):
weights = []
samples = []

for afill in self._staged:
data_list.append(afill["args"])
weights.append(afill["kwargs"]["weight"])
samples.append(afill["kwargs"]["sample"])
dask_data = tuple(
datum
for datum in (
self._staged[0]["args"] + tuple(self._staged[0]["kwargs"].values())
)
if is_dask_collection(datum)
)

if all(weight is None for weight in weights):
weights = None
if is_dask_awkward_like(dask_data[0]):

if not all(sample is None for sample in samples):
samples = None
for afill in self._staged:
data_list.append(afill["args"])
weights.append(afill["kwargs"]["weight"])
samples.append(afill["kwargs"]["sample"])

split_every = dask.config.get("histogram.aggregation.split_every", 8)
if all(weight is None for weight in weights):
weights = None

fills = _partitioned_histogram_multifill(
data_list, self._histref, weights, samples
)
if not all(sample is None for sample in samples):
samples = None

split_every = dask.config.get("histogram.aggregation.split_every", 8)

fills = _partitioned_histogram_multifill(
data_list, self._histref, weights, samples
)

output_hist = _reduction(fills, split_every)
output_hist = _reduction(fills, split_every)
else:

first_fill = self._staged.pop()

output_hist = factory(
*first_fill["args"],
histref=self._histref,
weights=first_fill["kwargs"]["weight"],
sample=first_fill["kwargs"]["sample"],
)

for afill in self._staged:
output_hist += factory(
*afill["args"],
histref=self._histref,
weights=afill["kwargs"]["weight"],
sample=afill["kwargs"]["sample"],
)

self._staged = None
self._staged_result = output_hist
Expand Down
2 changes: 1 addition & 1 deletion src/dask_histogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def _blocked_multi_dak(
sample = None if samples is None else samples[idata]

if backend != "typetracer":
thehist.fill(*tuple(data), weight=weight, sample=sample)
thehist.fill(*data, weight=weight, sample=sample)
else:
for datum in data:
if isinstance(datum, ak.highlevel.Array):
Expand Down

0 comments on commit cc53791

Please sign in to comment.