Skip to content

Commit

Permalink
Merge pull request #872 from nikoladze/dev-dak-elementlinks
Browse files Browse the repository at this point in the history
fix: Daskify Elementlinks in PHYSLITE schema
  • Loading branch information
lgray authored Sep 12, 2023
2 parents 583bc1a + 6abc42c commit 47a9304
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 71 deletions.
100 changes: 76 additions & 24 deletions src/coffea/nanoevents/methods/physlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numbers import Number

import awkward
import dask_awkward
import numpy

from coffea.nanoevents.methods import base, vector
Expand Down Expand Up @@ -38,7 +39,30 @@ def _element_link(target_collection, eventindex, index, key):
return target_collection._apply_global_index(global_index)


def _element_link_method(self, link_name, target_name, _dask_array_):
    """Resolve the element links stored under ``link_name`` into ``target_name``.

    When ``_dask_array_`` is ``None`` the link record and event index are
    read from ``self`` and the target collection from ``self._events()``.
    Otherwise all three are read from ``_dask_array_``, with the target
    collection obtained through its ``"__original_array__"`` behavior entry.
    In both cases the actual lookup is delegated to ``_element_link``.
    """
    if _dask_array_ is None:
        # no dask array supplied: work directly on this array
        links = self[link_name]
        return _element_link(
            self._events()[target_name],
            self._eventindex,
            links.m_persIndex,
            links.m_persKey,
        )

    # dask array supplied: take the target from the array it was derived from
    original_events = _dask_array_.behavior["__original_array__"]()
    target = original_events[target_name]
    links = _dask_array_[link_name]
    return _element_link(
        target,
        _dask_array_._eventindex,
        links.m_persIndex,
        links.m_persKey,
    )


def _element_link_multiple(events, obj, link_field, with_name=None):
# currently not working in dask because:
# - we don't know the resulting type beforehand
# - also not the targets, so no way to find out which columns to load?
# - could consider to treat the case of truth collections by just loading all truth columns
link = obj[link_field]
key = link.m_persKey
index = link.m_persIndex
Expand All @@ -64,22 +88,46 @@ def where(unique_keys):
return out


def _get_target_offsets(load_column, event_index):
    """Look up the offsets of ``load_column`` at the positions in ``event_index``.

    Parameters
    ----------
    load_column
        Array whose ``layout.offsets.data`` provides the offsets.
    event_index
        Either a plain number, or an array whose layout is walked with
        ``awkward.transform``.

    Returns
    -------
    ``offsets[event_index]`` for a plain number. For an array, the result of
    the transform, in which every layout with ``purelist_depth == 1`` is
    replaced by the offsets indexed with that layout. When both arguments
    are ``dask_awkward.Array`` the function is re-applied to each partition
    through ``dask_awkward.map_partitions``.
    """
    if isinstance(load_column, dask_awkward.Array) and isinstance(
        event_index, dask_awkward.Array
    ):
        # wrap in map_partitions if dask arrays
        return dask_awkward.map_partitions(
            _get_target_offsets, load_column, event_index
        )

    offsets = load_column.layout.offsets.data

    if isinstance(event_index, Number):
        return offsets[event_index]

    # let the necessary column optimization know that we need to load this
    # column to get the offsets
    if awkward.backend(load_column) == "typetracer":
        awkward.typetracer.touch_data(load_column)

    # necessary to stick it into the `NumpyArray` constructor
    # if typetracer is passed through
    offsets = awkward.typetracer.length_zero_if_typetracer(
        load_column.layout.offsets.data
    )

    def descend(layout, depth, **kwargs):
        # only the innermost list level is replaced; returning None for
        # every other node leaves it to awkward.transform
        if layout.purelist_depth == 1:
            return awkward.contents.NumpyArray(offsets)[layout]

    return awkward.transform(descend, event_index.layout)


def _get_global_index(target, eventindex, index):
load_column = target[
target.fields[0]
] # awkward is eager-mode now (will need to dask this)
target_offsets = _get_target_offsets(load_column.layout.offsets, eventindex)
for field in target.fields:
# fetch first column to get offsets from
# (but try to avoid the double-jagged ones if possible)
load_column = target[field]
if load_column.ndim < 3:
break
target_offsets = _get_target_offsets(load_column, eventindex)
return target_offsets + index


Expand Down Expand Up @@ -140,12 +188,12 @@ class Muon(Particle):
"""

@property
def trackParticle(self, _dask_array_=None):
    """Resolve ``combinedTrackParticleLink`` into ``CombinedMuonTrackParticles``.

    The work is delegated to ``_element_link_method``; ``_dask_array_`` is
    forwarded unchanged and defaults to ``None``.
    """
    return _element_link_method(
        self,
        "combinedTrackParticleLink",
        "CombinedMuonTrackParticles",
        _dask_array_,
    )


Expand All @@ -159,21 +207,25 @@ class Electron(Particle):
"""

@property
def trackParticles(self, _dask_array_=None):
    """Resolve ``trackParticleLinks`` into ``GSFTrackParticles``.

    The work is delegated to ``_element_link_method``; ``_dask_array_`` is
    forwarded unchanged and defaults to ``None``.
    """
    return _element_link_method(
        self, "trackParticleLinks", "GSFTrackParticles", _dask_array_
    )

@property
def trackParticle(self, _dask_array_=None):
    """Return the first linked track particle (index 0 along the last axis).

    The links are resolved exactly as in ``trackParticles`` (same link and
    target names passed to ``_element_link_method``), then the result is
    sliced with ``0`` on its last axis.
    """
    trackParticles = _element_link_method(
        self, "trackParticleLinks", "GSFTrackParticles", _dask_array_
    )
    # Ellipsis (..., 0) slicing not supported yet by dask_awkward, so spell
    # out one full slice per leading axis
    slicer = (slice(None),) * (trackParticles.ndim - 1) + (0,)
    return trackParticles[slicer]

@property
def caloClusters(self, _dask_array_=None):
    """Resolve ``caloClusterLinks`` into ``CaloCalTopoClusters``.

    The work is delegated to ``_element_link_method``; ``_dask_array_`` is
    forwarded unchanged and defaults to ``None``.
    """
    return _element_link_method(
        self,
        link_name="caloClusterLinks",
        target_name="CaloCalTopoClusters",
        _dask_array_=_dask_array_,
    )


_set_repr_name("Electron")
Expand Down
1 change: 1 addition & 0 deletions src/coffea/nanoevents/schemas/physlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class PHYSLITESchema(BaseSchema):
"GSFTrackParticles": "TrackParticle",
"InDetTrackParticles": "TrackParticle",
"MuonSpectrometerTrackParticles": "TrackParticle",
"CaloCalTopoClusters": "NanoCollection",
}
"""Default configuration for mixin types, based on the collection name.
Expand Down
59 changes: 12 additions & 47 deletions tests/test_nanoevents_physlite.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import os

import numpy as np
import dask
import pytest

from coffea.nanoevents import NanoEventsFactory, PHYSLITESchema

pytestmark = pytest.mark.skip(reason="uproot is upset with this file...")


def _events():
    """Build the events object for the PHYSLITE sample file used by these tests.

    The file is opened with ``PHYSLITESchema`` and ``permit_dask=True``; the
    tests in this module call ``.compute()`` on what they read from it.
    """
    path = os.path.abspath("tests/samples/DAOD_PHYSLITE_21.2.108.0.art.pool.root")
    factory = NanoEventsFactory.from_root(
        {path: "CollectionTree"},
        schemaclass=PHYSLITESchema,
        permit_dask=True,
    )
    return factory.events()

Expand All @@ -23,54 +21,21 @@ def events():
return _events()


def test_load_single_field_of_linked(events):
    """Compute a single field (``calE``) of the electrons' linked calo clusters.

    The test has no assertion; it passes when the computation completes
    without raising.
    """
    # NOTE(review): presumably this setting turns a failed metadata
    # inference into an error instead of a silent fallback -- confirm
    strict_meta = {"awkward.raise-failed-meta": True}
    with dask.config.set(strict_meta):
        cluster_energy = events.Electrons.caloClusters.calE
        cluster_energy.compute()


@pytest.mark.parametrize("do_slice", [False, True])
def test_electron_track_links(events, do_slice):
if do_slice:
events = events[np.random.randint(2, size=len(events)).astype(bool)]
for event in events:
for electron in event.Electrons:
events = events[::2]
trackParticles = events.Electrons.trackParticles.compute()
for i, event in enumerate(events[["Electrons", "GSFTrackParticles"]].compute()):
for j, electron in enumerate(event.Electrons):
for link_index, link in enumerate(electron.trackParticleLinks):
track_index = link.m_persIndex
print(track_index)
print(event.GSFTrackParticles)
print(electron.trackParticleLinks)
print(electron.trackParticles)

assert (
event.GSFTrackParticles[track_index].z0
== electron.trackParticles[link_index].z0
)


# from MetaData/EventFormat
_hash_to_target_name = {
13267281: "TruthPhotons",
342174277: "TruthMuons",
368360608: "TruthNeutrinos",
375408000: "TruthTaus",
394100163: "TruthElectrons",
614719239: "TruthBoson",
660928181: "TruthTop",
779635413: "TruthBottom",
}


def test_truth_links_toplevel(events):
children_px = events.TruthBoson.children.px
for i_event, event in enumerate(events):
for i_particle, particle in enumerate(event.TruthBoson):
for i_link, link in enumerate(particle.childLinks):
assert (
event[_hash_to_target_name[link.m_persKey]][link.m_persIndex].px
== children_px[i_event][i_particle][i_link]
)


def test_truth_links(events):
for i_event, event in enumerate(events):
for i_particle, particle in enumerate(event.TruthBoson):
for i_link, link in enumerate(particle.childLinks):
assert (
event[_hash_to_target_name[link.m_persKey]][link.m_persIndex].px
== particle.children[i_link].px
== trackParticles[i][j][link_index].z0
)

0 comments on commit 47a9304

Please sign in to comment.