Skip to content

Commit

Permalink
remove ugly adding hack since we handle it differently in boost histo…
Browse files Browse the repository at this point in the history
…gram now
  • Loading branch information
lgray authored Feb 23, 2024
1 parent 9fc53af commit bcc3525
Showing 1 changed file with 1 addition and 88 deletions.
89 changes: 1 addition & 88 deletions src/dask_histogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,7 @@ def counts(self, flow: bool = False) -> NDArray[Any]:
def __array__(self) -> NDArray[Any]:
return self.compute().__array__()

def __iadd__(self, other: Any) -> AggHistogram:
if isinstance(other, AggHistogram):
return _iadd_hist(self, other)
def __iadd__(self, other: Any) -> AggHistogram:
return _iadd(self, other)

def __add__(self, other: Any) -> AggHistogram:
Expand Down Expand Up @@ -1058,91 +1056,6 @@ def __call__(self, a: AggHistogram, b: AggHistogram) -> AggHistogram:
return AggHistogram(g, name, histref=ref)


class UnorderedTreeReductionBinaryOp:
def __init__(
self,
func: Callable[[Any], Any],
name: str | None = None,
) -> None:
self.func = func
self.__name__ = func.__name__ if name is None else name

def __call__(self, a: AggHistogram, b: AggHistogram) -> AggHistogram:
token = tokenize(a, b)
name = f"{self.__name__}-hist-{token}"
name_comb = f"{self.__name__}-combine-{token}"
deps = []
if is_dask_collection(a):
deps.append(a)
if is_dask_collection(b):
deps.append(b)

layer_a = (
[
value
for value in a.dask.layers.values()
if isinstance(value, MockableMultiSourceTreeReduction)
]
if is_dask_collection(a)
else a
)
layer_b = (
[
value
for value in b.dask.layers.values()
if isinstance(value, MockableMultiSourceTreeReduction)
]
if is_dask_collection(b)
else b
)
layer_a = layer_a[0] if len(layer_a) == 1 else None
layer_b = layer_b[0] if len(layer_b) == 1 else None

layer_a_names = layer_a.names_inputs if layer_a else tuple()
layer_a_parts = layer_a.npartitions_inputs if layer_a else tuple()
layer_b_names = layer_b.names_inputs if layer_b else tuple()
layer_b_parts = layer_b.npartitions_inputs if layer_b else tuple()

a_b_reduction = MockableMultiSourceTreeReduction(
name=name,
names_inputs=(layer_a_names + layer_b_names),
npartitions_inputs=(layer_a_parts + layer_b_parts),
concat_func=self.func,
tree_node_func=layer_a.tree_node_func,
finalize_func=layer_a.finalize_func,
split_every=layer_a.split_every,
tree_node_name=name_comb,
)
layers = {name: a_b_reduction}
name_dep = set(a_b_reduction.names_inputs)
deps = {name: name_dep}
layers_a = a.dask.layers.copy()
if layer_a:
layers_a.pop(a.name)
layers_b = b.dask.layers.copy()
if layer_b:
layers_b.pop(b.name)
layers.update(layers_a)
layers.update(layers_b)

a_deps = a.dask.dependencies.copy()
a_deps.pop(a.name)
b_deps = b.dask.dependencies.copy()
b_deps.pop(b.name)

deps.update(a_deps)
deps.update(b_deps)

g = HighLevelGraph(layers, deps)
try:
ref = a.histref
except AttributeError:
ref = b.histref

return AggHistogram(g, name, histref=ref)


_iadd_hist = UnorderedTreeReductionBinaryOp(hist_safe_sum, "add")
_iadd = BinaryOpAgg(operator.iadd, name="add")
_isub = BinaryOpAgg(operator.isub, name="sub")
_imul = BinaryOpAgg(operator.imul, name="mul")
Expand Down

0 comments on commit bcc3525

Please sign in to comment.