From 691dde21c4f674a244936feec0d8493356d22634 Mon Sep 17 00:00:00 2001 From: kreczko Date: Fri, 2 Aug 2019 15:21:45 +0100 Subject: [PATCH] extended VectorizedHistProxy to allow event weights and per-object weigths --- cmsl1t/collections/vectorized.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cmsl1t/collections/vectorized.py b/cmsl1t/collections/vectorized.py index 7a776db00cb..31fea7e29f8 100644 --- a/cmsl1t/collections/vectorized.py +++ b/cmsl1t/collections/vectorized.py @@ -19,15 +19,21 @@ def extend(arr1, starts, stops): def split_input(inner_indices, x, w): content = x + weights = w if hasattr(x, 'starts'): inner_indices = extend(inner_indices, x.starts, x.stops) content = x.content + if hasattr(w, 'starts'): + weights = w.content + + if np.size(weights) < np.size(content) and hasattr(x, 'starts'): + weights = extend(weights, x.starts, x.stops) for u in np.unique(inner_indices): mask = inner_indices == u if not isinstance(mask, (tuple, list, np.ndarray, np.generic)): mask = np.array(mask) - yield u, content[mask], w[mask] + yield u, content[mask], weights[mask] class VectorizedHistCollection(BaseHistCollection): @@ -147,6 +153,7 @@ def _get_hist(self, inner_index): def fill(self, x, w=None): if not isinstance(x, (tuple, list, np.ndarray, awkward.JaggedArray)): x = np.array(x) + if w is None: n = np.size(x.content) if hasattr(x, 'content') else np.size(x) w = np.ones(n)