Skip to content

Commit

Permalink
Store histref in a partial function for propagation through graph. (#157
Browse files Browse the repository at this point in the history
)



Co-authored-by: Dmitry Kalinkin <[email protected]>
  • Loading branch information
douglasdavis and veprbl authored Dec 19, 2024
1 parent bf85767 commit f7d236e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 5 deletions.
Empty file removed afile
Empty file.
7 changes: 3 additions & 4 deletions src/dask_histogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ def _blocked_multi(
repacker: Callable,
*flattened_inputs: tuple[Any],
) -> bh.Histogram:

data_list, weights, samples, histref = repacker(flattened_inputs)

weights = weights or (None for _ in range(len(data_list)))
Expand Down Expand Up @@ -439,7 +438,6 @@ def _blocked_multi_df(
repacker: Callable,
*flattened_inputs: tuple[Any],
) -> bh.Histogram:

data_list, weights, samples, histref = repacker(flattened_inputs)

weights = weights or (None for _ in range(len(data_list)))
Expand Down Expand Up @@ -1027,8 +1025,9 @@ def _partitioned_histogram(
if len(data) == 1 and data_is_dak:
from dask_awkward.lib.core import partitionwise_layer as dak_pwl

f = partial(_blocked_dak, weights=weights, sample=sample, histref=histref)
g = dak_pwl(f, name, data[0])
f = partial(_blocked_dak, histref=histref)

g = dak_pwl(f, name, data[0], weights, sample)

# Single object, not a dataframe
elif len(data) == 1 and not data_is_df:
Expand Down
54 changes: 53 additions & 1 deletion tests/test_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,6 @@ def test_155_boost_factory():
import boost_histogram as bh

dak = pytest.importorskip("dask_awkward")
import numpy as np

import dask_histogram as dh

Expand All @@ -584,3 +583,56 @@ def test_155_boost_factory():
axes=(axis,),
).compute()
assert np.all(hist.values() == [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0])


def test_155_2():
import boost_histogram as bh

import dask_histogram as dh

dak = pytest.importorskip("dask_awkward")

arr = dak.from_lists([list(range(10))] * 3)
axis = bh.axis.Regular(10, 0.0, 10.0)
hist = dh.factory(
arr,
axes=(axis,),
weights=arr,
).compute()
assert np.all(
hist.values() == [0.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0]
)


def test_155_3_2d():
import boost_histogram as bh

dak = pytest.importorskip("dask_awkward")

import dask_histogram as dh

arr1 = dak.from_lists([list(range(10))] * 3)
arr2 = dak.from_lists([list(reversed(range(10)))] * 3)
axis1 = bh.axis.Regular(10, 0.0, 10.0)
axis2 = bh.axis.Regular(10, 0.0, 10.0)
hist = dh.factory(
arr1,
arr2,
axes=(axis1, axis2),
weights=arr1,
).compute()
should_be = (
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 12.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 18.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 21.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[27.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
],
)
assert np.all(hist.values() == should_be)

0 comments on commit f7d236e

Please sign in to comment.