Skip to content

Commit

Permalink
Merge pull request #283 from RasmusOrsoe/HE_converter_adaptation
Browse files Browse the repository at this point in the history
New i3 extractors for high energy, small bug fixes
  • Loading branch information
RasmusOrsoe authored Oct 6, 2022
2 parents ecf83e4 + 73ffa49 commit 9808766
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/graphnet/data/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@
from .i3featureextractor import *
from .i3truthextractor import *
from .i3retroextractor import *
from .i3splinempeextractor import I3SplineMPEICExtractor
from .i3tumextractor import I3TUMExtractor
from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor
from .i3genericextractor import I3GenericExtractor
49 changes: 47 additions & 2 deletions src/graphnet/data/extractors/i3featureextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def __init__(self, pulsemap):
class I3FeatureExtractorIceCube86(I3FeatureExtractor):
def __call__(self, frame) -> dict:
"""Extract features to be used as inputs to GNN models."""

output = {
"charge": [],
"dom_time": [],
Expand All @@ -27,6 +26,10 @@ def __call__(self, frame) -> dict:
"width": [],
"pmt_area": [],
"rde": [],
"is_bright_dom": [],
"is_bad_dom": [],
"is_saturated_dom": [],
"is_errata_dom": [],
}

# Get OM data
Expand All @@ -36,6 +39,23 @@ def __call__(self, frame) -> dict:
self._calibration,
)

# Added these :
bright_doms = None
bad_doms = None
saturation_windows = None
calibration_errata = None
if "BrightDOMs" in frame:
bright_doms = frame.Get("BrightDOMs")

if "BadDomsList" in frame:
bad_doms = frame.Get("BadDomsList")

if "SaturationWindows" in frame:
saturation_windows = frame.Get("SaturationWindows")

if "CalibrationErrata" in frame:
calibration_errata = frame.Get("CalibrationErrata")

for om_key in om_keys:
# Common values for each OM
x = self._gcd_dict[om_key].position.x
Expand All @@ -44,6 +64,27 @@ def __call__(self, frame) -> dict:
area = self._gcd_dict[om_key].area
rde = self._get_relative_dom_efficiency(frame, om_key)

# DOM flags
if bright_doms:
is_bright_dom = 1 if om_key in bright_doms else 0
else:
is_bright_dom = -1

if bad_doms:
is_bad_dom = 1 if om_key in bad_doms else 0
else:
is_bad_dom = -1

if saturation_windows:
is_saturated_dom = 1 if om_key in saturation_windows else 0
else:
is_saturated_dom = -1

if calibration_errata:
is_errata_dom = 1 if om_key in calibration_errata else 0
else:
is_errata_dom = -1

# Loop over pulses for each OM
pulses = data[om_key]
for pulse in pulses:
Expand All @@ -55,7 +96,11 @@ def __call__(self, frame) -> dict:
output["dom_x"].append(x)
output["dom_y"].append(y)
output["dom_z"].append(z)

# DOM flags
output["is_bright_dom"].append(is_bright_dom)
output["is_bad_dom"].append(is_bad_dom)
output["is_saturated_dom"].append(is_saturated_dom)
output["is_errata_dom"].append(is_errata_dom)
return output

def _get_relative_dom_efficiency(self, frame, om_key):
Expand Down
40 changes: 40 additions & 0 deletions src/graphnet/data/extractors/i3hybridrecoextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from graphnet.data.extractors.i3extractor import I3Extractor


class I3GalacticPlaneHybridRecoExtractor(I3Extractor):
def __init__(self, name="dnn_hybrid"):
super().__init__(name)

def __call__(self, frame) -> dict:
"""Extracts TUMs DNN Recos and associated variables"""
output = {}
if "DNNCascadeAnalysis_version_001_p00" in frame:
reco_object = frame["DNNCascadeAnalysis_version_001_p00"]
keys = [
"angErr",
"angErr_uncorrected",
"dec",
"dpsi",
"energy",
"event",
"ra",
"run",
"subevent",
"time",
"trueDec",
"trueE",
"trueRa",
"true_azi",
"true_zen",
]
for key in keys:
output.update({key: reco_object[key]})
output.update(
{
"zenith_hybrid": reco_object["zen"],
"azimuth_hybrid": reco_object["azi"],
"energy_hybrid_log": reco_object["logE"],
}
)

return output
19 changes: 19 additions & 0 deletions src/graphnet/data/extractors/i3splinempeextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from graphnet.data.extractors.i3extractor import I3Extractor


class I3SplineMPEICExtractor(I3Extractor):
def __init__(self, name="spline_mpe_ic"):
super().__init__(name)

def __call__(self, frame) -> dict:
"""Extracts SplineMPE pointing predictions."""
output = {}
if "SplineMPEIC" in frame:
output.update(
{
"zenith_spline_mpe_ic": frame["SplineMPEIC"].dir.zenith,
"azimuth_spline_mpe_ic": frame["SplineMPEIC"].dir.azimuth,
}
)

return output
18 changes: 13 additions & 5 deletions src/graphnet/data/extractors/i3truthextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def __call__(self, frame, padding_value=-1) -> dict:
# Only InIceSplit P frames contain ML appropriate I3RecoPulseSeriesMap etc.
# At low levels i3files contain several other P frame splits (e.g NullSplit),
# we remove those here.
if frame["I3EventHeader"].sub_event_stream != "InIceSplit":
if frame["I3EventHeader"].sub_event_stream not in [
"InIceSplit",
"Final",
]:
return output

if "FilterMask" in frame:
Expand Down Expand Up @@ -146,10 +149,15 @@ def __call__(self, frame, padding_value=-1) -> dict:
) = self._get_primary_particle_interaction_type_and_elasticity(
frame, sim_type
)
(
energy_track,
inelasticity,
) = self._get_primary_track_energy_and_inelasticity(frame)
try:
(
energy_track,
inelasticity,
) = self._get_primary_track_energy_and_inelasticity(frame)
except RuntimeError: # track energy fails on northeren tracks with ""Hadrons" has no mass implemented. Cannot get total energy."
energy_track = (padding_value,)
inelasticity = (padding_value,)

output.update(
{
"energy": MCInIcePrimary.energy,
Expand Down
22 changes: 22 additions & 0 deletions src/graphnet/data/extractors/i3tumextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from graphnet.data.extractors.i3extractor import I3Extractor


class I3TUMExtractor(I3Extractor):
def __init__(self, name="tum_dnn"):
super().__init__(name)

def __call__(self, frame) -> dict:
"""Extracts TUM DNN Recos and associated variables"""
output = {}
if "TUM_dnn_energy_hive" in frame:
output.update(
{
"tum_dnn_energy_hive": 10
** frame["TUM_dnn_energy_hive"]["mu_E_on_entry"],
"tum_dnn_energy_dst": 10
** frame["TUM_dnn_energy_dst"]["mu_E_on_entry"],
"tum_bdt_sigma": frame["TUM_bdt_sigma"].value,
}
)

return output

0 comments on commit 9808766

Please sign in to comment.