Skip to content

Commit

Permalink
Merge pull request graphnet-team#652 from RasmusOrsoe/update-QUESO-mo…
Browse files Browse the repository at this point in the history
…dels

Update `QUESO` energy model and unit test
  • Loading branch information
RasmusOrsoe authored Jan 26, 2024
2 parents f8d88b8 + e73ff9d commit b686dd7
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
Binary file not shown.
7 changes: 6 additions & 1 deletion src/graphnet/deployment/i3modules/graphnet_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from graphnet.models.graphs import GraphDefinition
from graphnet.utilities.imports import has_icecube_package
from graphnet.utilities.config import ModelConfig
from graphnet.utilities.logging import Logger

if has_icecube_package() or TYPE_CHECKING:
from icecube.icetray import (
Expand All @@ -28,7 +29,7 @@
from icecube import dataclasses, dataio, icetray


class GraphNeTI3Module:
class GraphNeTI3Module(Logger):
"""Base I3 Module for GraphNeT.
Contains methods for extracting pulsemaps, producing graphs and writing to
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
pulsemap from the I3Frames
gcd_file: Path to the associated gcd-file.
"""
super().__init__(name=__name__, class_name=self.__class__.__name__)
assert isinstance(graph_definition, GraphDefinition)
self._graph_definition = graph_definition
self._pulsemap = pulsemap
Expand Down Expand Up @@ -200,6 +202,9 @@ def __call__(self, frame: I3Frame) -> bool:
if graph is not None:
predictions = self._inference(graph)
else:
self.warning(
f"At least one event has no pulses in {self._pulsemap} - padding {self.prediction_columns} with NaN."
)
predictions = np.repeat(
[np.nan], len(self.prediction_columns)
).reshape(-1, len(self.prediction_columns))
Expand Down
Binary file not shown.
29 changes: 18 additions & 11 deletions tests/deployment/queso_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from os.path import join
from typing import TYPE_CHECKING, List, Sequence, Dict, Tuple, Any
import os
import numpy as np
import pytest

from graphnet.data.constants import FEATURES
Expand Down Expand Up @@ -139,16 +140,16 @@ def extract_predictions(
Returns:
Predictions from each model for each frame.
"""
file = dataio.I3File(file)
open_file = dataio.I3File(file)
data = []
while file.more(): # type: ignore
frame = file.pop_physics() # type: ignore
while open_file.more(): # type: ignore
frame = open_file.pop_physics() # type: ignore
predictions = {}
for frame_entry in frame.keys():
for model_path in model_paths:
model = model_path.split("/")[-1]
if model in frame_entry:
predictions[model] = frame[frame_entry].value
predictions[frame_entry] = frame[frame_entry].value
data.append(predictions)
return data

Expand Down Expand Up @@ -193,9 +194,7 @@ def test_deployment() -> None:
def verify_QUESO_integrity() -> None:
"""Test new and original i3 files contain same predictions."""
base_path = f"{PRETRAINED_MODEL_DIR}/icecube/upgrade/QUESO/"
queso_original_file = glob(
f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/*.i3.gz"
)[0]
queso_original_file = glob(f"{TEST_DATA_DIR}/deployment/QUESO/*.i3.gz")[0]
queso_new_file = glob(f"{TEST_DATA_DIR}/output/QUESO_test/*.i3.gz")[0]
queso_models = glob(base_path + "/*")

Expand All @@ -210,10 +209,18 @@ def verify_QUESO_integrity() -> None:
for frame in range(len(original_predictions)):
for model in original_predictions[frame].keys():
assert model in new_predictions[frame].keys()
assert (
new_predictions[frame][model]
== original_predictions[frame][model]
)
try:
assert np.isclose(
new_predictions[frame][model],
original_predictions[frame][model],
equal_nan=True,
)
except AssertionError as e:
print(
f"Mismatch found in {model}: {new_predictions[frame][model]} vs. {original_predictions[frame][model]}"
)
raise e

return


Expand Down

0 comments on commit b686dd7

Please sign in to comment.