Commit d8c899a

trial of multifill but has some stalling issues
lgray committed Feb 28, 2024
1 parent 8d60de2 commit d8c899a
Showing 2 changed files with 123 additions and 19 deletions.
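
At a glance: instead of building one partitioned histogram per `fill()` call, `_build_taskgraph` now collects every staged fill and hands the whole batch to a single `_partitioned_histogram_multifill`. Each `fill()` call is staged as a plain dict; a minimal sketch of the staged structure after two calls (x1, x2, w1 are hypothetical dask-awkward collections, not names from this diff):

    # Illustrative shape of self._staged after two fill() calls
    staged = [
        {"args": (x1,), "kwargs": {"weight": w1, "sample": None}},
        {"args": (x2,), "kwargs": {"weight": None, "sample": None}},
    ]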
62 changes: 44 additions & 18 deletions src/dask_histogram/boost.py
@@ -19,11 +19,11 @@
from tlz import first

from dask_histogram.bins import normalize_bins_range
from dask_histogram.core import (
from dask_histogram.core import ( # partitioned_factory,
AggHistogram,
_get_optimization_function,
_partitioned_histogram_multifill,
hist_safe_sum,
partitioned_factory,
)
from dask_histogram.layers import MockableMultiSourceTreeReduction

@@ -196,24 +196,50 @@ def dask(self) -> HighLevelGraph:
return self._dask

def _build_taskgraph(self):
first_args = self._staged.pop()
first_hist = partitioned_factory(
*first_args["args"], histref=self._histref, **first_args["kwargs"]
)
fills = [first_hist]
for filling_info in self._staged:
fills.append(
partitioned_factory(
*filling_info["args"],
histref=self._histref,
**filling_info["kwargs"],
)
)

label = "histreduce"
# first_args = self._staged.pop()
# first_hist = partitioned_factory(
# *first_args["args"], histref=self._histref, **first_args["kwargs"]
# )
# fills = [first_hist]
# for filling_info in self._staged:
# fills.append(
# partitioned_factory(
# *filling_info["args"],
# histref=self._histref,
# **filling_info["kwargs"],
# )
# )

data_list = []
weights = []
samples = []

for afill in self._staged:
data_list.append(afill["args"])
weights.append(afill["kwargs"]["weight"])
samples.append(afill["kwargs"]["sample"])

if all(weight is None for weight in weights):
weights = None

        if all(sample is None for sample in samples):
            samples = None

split_every = dask.config.get("histogram.aggregation.split_every", 8)

# fills = [_partitioned_histogram_multifill(
# tuple(data_list[i:i+split_every]),
# self._histref,
# tuple(weights[i:i+split_every]),
# tuple(samples[i:i+split_every]),
# ) for i in range(0,len(self._staged),split_every)]

fills = [
_partitioned_histogram_multifill(data_list, self._histref, weights, samples)
]

label = "histreduce"

token = tokenize(*fills, hist_safe_sum, split_every)

name_comb = f"{label}-combine-{token}"
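
For context, the `fills` list feeds the combine stage named above: partition-level histograms are merged with `hist_safe_sum` in groups of `split_every` via the `MockableMultiSourceTreeReduction` layer imported at the top of the file. Conceptually the reduction behaves like this sketch (not the actual layer implementation):

    # Conceptual tree reduction; the real code builds a
    # MockableMultiSourceTreeReduction layer with hist_safe_sum.
    def tree_reduce(parts, combine, split_every=8):
        while len(parts) > 1:
            parts = [
                combine(parts[i : i + split_every])
                for i in range(0, len(parts), split_every)
            ]
        return parts[0]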
@@ -312,7 +338,7 @@ def fill( # type: ignore
raise ValueError(f"Cannot interpret input data: {args}")

# new_fill = partitioned_factory(*args, histref=self._histref, weights=weight, sample=sample)
new_fill = {"args": args, "kwargs": {"weights": weight, "sample": sample}}
new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}}
if self._staged is None:
self._staged = [new_fill]
else:
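
Note that the commented-out block in `_build_taskgraph` above sketches a chunked alternative: batching staged fills in groups of `split_every` rather than one monolithic multifill. Re-enabled, it would also need to guard the slicing, since `weights` and `samples` may have been collapsed to `None`. A hypothetical reconstruction, not code from this commit:

    # Hypothetical chunked multifill, guarded against weights/samples
    # being None (the commented-out version would raise a TypeError).
    fills = [
        _partitioned_histogram_multifill(
            tuple(data_list[i : i + split_every]),
            self._histref,
            None if weights is None else tuple(weights[i : i + split_every]),
            None if samples is None else tuple(samples[i : i + split_every]),
        )
        for i in range(0, len(data_list), split_every)
    ]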
80 changes: 79 additions & 1 deletion src/dask_histogram/core.py
@@ -337,7 +337,11 @@ def _blocked_dak_ma_w(
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
)
return thehist.fill(*tuple(thedata), weight=theweights)

if ak.backend(*data) != "typetracer":
thehist.fill(*tuple(thedata), weight=theweights)

return thehist
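
The new guard skips the actual fill when the inputs are typetracer arrays (dask-awkward's metadata-only backend), returning the empty histogram so the graph can still be traced. The same pattern in a standalone sketch (assumes awkward v2; `safe_fill` is a hypothetical helper, not part of this diff):

    import awkward as ak
    import boost_histogram as bh

    def safe_fill(hist: bh.Histogram, data: ak.Array) -> bh.Histogram:
        # Typetracer arrays carry types but no values, so fill() is skipped.
        if ak.backend(data) != "typetracer":
            hist.fill(ak.to_numpy(data))
        return hist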


def _blocked_dak_ma_s(
@@ -401,6 +405,51 @@ def _blocked_dak_ma_w_s(
return thehist.fill(*tuple(thedata), weight=theweights, sample=thesample)


def _blocked_multi_dak(
data_list: tuple[tuple[Any]],
weights: tuple[Any] | None,
samples: tuple[Any] | None,
histref: tuple | bh.Histogram | None = None,
) -> bh.Histogram:
import awkward as ak

thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
)

backend = ak.backend(*data_list[0])

for idata, data in enumerate(data_list):
weight = None if weights is None else weights[idata]
sample = None if samples is None else samples[idata]

thedata = [
(
ak.typetracer.length_zero_if_typetracer(datum)
if isinstance(datum, ak.Array)
else datum
)
for datum in data
]
theweight = (
ak.typetracer.length_zero_if_typetracer(weight)
if isinstance(weight, ak.Array)
else weight
)
thesample = (
ak.typetracer.length_zero_if_typetracer(sample)
if isinstance(sample, ak.Array)
else sample
)

if backend != "typetracer":
thehist.fill(*tuple(thedata), weight=theweight, sample=thesample)

return thehist
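
`_blocked_multi_dak` is the per-partition kernel: one cloned histogram absorbs every staged fill for its partition. An eager usage sketch on concrete arrays (all values illustrative):

    import awkward as ak
    import boost_histogram as bh

    ref = bh.Histogram(bh.axis.Regular(10, 0, 1), storage=bh.storage.Weight())
    h = _blocked_multi_dak(
        ((ak.Array([0.1, 0.5]),), (ak.Array([0.7]),)),  # two staged fills
        (ak.Array([1.0, 2.0]), None),                    # per-fill weights
        None,                                            # no samples
        histref=ref,
    )
    # h.sum().value == 4.0: weights 1.0 + 2.0, plus a default weight of 1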


def optimize(
dsk: Mapping,
keys: Hashable | list[Hashable] | set[Hashable],
@@ -874,6 +923,35 @@ def _partitionwise(func, layer_name, *args, **kwargs):
)


class PackedMultifill:
def __init__(self, repacker):
self.repacker = repacker

def __call__(self, *args):
return _blocked_multi_dak(*self.repacker(args))
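
`PackedMultifill` wraps the repacker returned by `dask.base.unpack_collections` so the fill travels through the graph as a single named callable: at task time the flattened per-partition inputs are repacked into `(data_list, weights, samples, histref)` and forwarded to `_blocked_multi_dak`. The round trip on plain, non-collection values looks like this sketch:

    # dask.base.unpack_collections round trip on non-collection values:
    from dask.base import unpack_collections

    collections, repack = unpack_collections(("a", "b"), None, {"k": 1})
    assert collections == []                    # nothing daskish to extract
    assert repack([]) == (("a", "b"), None, {"k": 1})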


def _partitioned_histogram_multifill(
data: tuple[DaskCollection | tuple],
histref: bh.Histogram | tuple,
weights: tuple[DaskCollection] | None = None,
samples: tuple[DaskCollection] | None = None,
):
name = f"hist-on-block-{tokenize(data, histref, weights, samples)}"

from dask.base import unpack_collections
from dask_awkward.lib.core import partitionwise_layer as dak_pwl

flattened_deps, repacker = unpack_collections(data, weights, samples, histref)

graph = dak_pwl(PackedMultifill(repacker), name, *flattened_deps)

hlg = HighLevelGraph.from_collections(name, graph, dependencies=flattened_deps)
return PartitionedHistogram(
hlg, name, flattened_deps[0].npartitions, histref=histref
)
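
`_partitioned_histogram_multifill` builds a single partitionwise layer that applies all staged fills to each partition and wraps it as a `PartitionedHistogram`. A hypothetical call, mirroring `_build_taskgraph` (x, y, w are dask-awkward collections with matching partitioning; all names illustrative):

    parted = _partitioned_histogram_multifill(
        ((x,), (y,)),  # one args-tuple per staged fill
        histref,       # bh.Histogram or packed (axes, storage, metadata) tuple
        (w, None),     # per-fill weights, or None if all fills are unweighted
        None,          # per-fill samples
    )
    # parted carries one bh.Histogram per partition; the tree reduction in
    # boost.py then sums partitions into the final AggHistogram.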


def _partitioned_histogram(
*data: DaskCollection,
histref: bh.Histogram | tuple,
