Skip to content

Commit

Permalink
extended VectorizedHistProxy to allow event weights and per-object we…
Browse files Browse the repository at this point in the history
…igths
  • Loading branch information
kreczko committed Aug 6, 2019
1 parent a67fdb7 commit 691dde2
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion cmsl1t/collections/vectorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,21 @@ def extend(arr1, starts, stops):

def split_input(inner_indices, x, w):
content = x
weights = w
if hasattr(x, 'starts'):
inner_indices = extend(inner_indices, x.starts, x.stops)
content = x.content
if hasattr(w, 'starts'):
weights = w.content

if np.size(weights) < np.size(content) and hasattr(x, 'starts'):
weights = extend(weights, x.starts, x.stops)

for u in np.unique(inner_indices):
mask = inner_indices == u
if not isinstance(mask, (tuple, list, np.ndarray, np.generic)):
mask = np.array(mask)
yield u, content[mask], w[mask]
yield u, content[mask], weights[mask]


class VectorizedHistCollection(BaseHistCollection):
Expand Down Expand Up @@ -147,6 +153,7 @@ def _get_hist(self, inner_index):
def fill(self, x, w=None):
if not isinstance(x, (tuple, list, np.ndarray, awkward.JaggedArray)):
x = np.array(x)

if w is None:
n = np.size(x.content) if hasattr(x, 'content') else np.size(x)
w = np.ones(n)
Expand Down

0 comments on commit 691dde2

Please sign in to comment.