
Commit

Merge branch 'main' into display_current_learning_rate
RasmusOrsoe authored Jan 26, 2024
2 parents ea6ebc8 + b686dd7 commit 4042375
Showing 8 changed files with 285 additions and 13 deletions.
Binary file not shown.
7 changes: 6 additions & 1 deletion src/graphnet/deployment/i3modules/graphnet_module.py
@@ -15,6 +15,7 @@
from graphnet.models.graphs import GraphDefinition
from graphnet.utilities.imports import has_icecube_package
from graphnet.utilities.config import ModelConfig
from graphnet.utilities.logging import Logger

if has_icecube_package() or TYPE_CHECKING:
from icecube.icetray import (
@@ -28,7 +29,7 @@
from icecube import dataclasses, dataio, icetray


-class GraphNeTI3Module:
+class GraphNeTI3Module(Logger):
"""Base I3 Module for GraphNeT.
Contains methods for extracting pulsemaps, producing graphs and writing to
@@ -56,6 +57,7 @@ def __init__(
pulsemap from the I3Frames
gcd_file: Path to the associated gcd-file.
"""
super().__init__(name=__name__, class_name=self.__class__.__name__)
assert isinstance(graph_definition, GraphDefinition)
self._graph_definition = graph_definition
self._pulsemap = pulsemap
@@ -200,6 +202,9 @@ def __call__(self, frame: I3Frame) -> bool:
if graph is not None:
predictions = self._inference(graph)
else:
self.warning(
f"At least one event has no pulses in {self._pulsemap} - padding {self.prediction_columns} with NaN."
)
predictions = np.repeat(
[np.nan], len(self.prediction_columns)
).reshape(-1, len(self.prediction_columns))
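For context, a minimal standalone sketch of the padding fallback above (the column names here are made up for illustration): when an event has no pulses, a single all-NaN row is produced, one NaN per prediction column.

import numpy as np

# Hypothetical prediction columns, for illustration only.
prediction_columns = ["energy_pred", "zenith_pred", "azimuth_pred"]

# Same padding as in the snippet above: one row of NaNs, one per column.
padded = np.repeat([np.nan], len(prediction_columns)).reshape(
    -1, len(prediction_columns)
)
print(padded)  # [[nan nan nan]]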
1 change: 1 addition & 0 deletions src/graphnet/models/graphs/edges/__init__.py
@@ -5,3 +5,4 @@
and their features.
"""
from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges
from .minkowski import MinkowskiKNNEdges
98 changes: 98 additions & 0 deletions src/graphnet/models/graphs/edges/minkowski.py
@@ -0,0 +1,98 @@
"""Module containing EdgeDefinitions based on the Minkowski Metric."""
from typing import Optional, List

import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_batch
from graphnet.models.graphs.edges.edges import EdgeDefinition


def compute_minkowski_distance_mat(
x: torch.Tensor,
y: torch.Tensor,
c: float,
space_coords: Optional[List[int]] = None,
time_coord: Optional[int] = 3,
) -> torch.Tensor:
"""Compute all pairwise Minkowski distances.
Args:
x: First tensor of shape (n, d).
y: Second tensor of shape (m, d).
c: Speed of light, in scaled units.
space_coords: Indices of space coordinates.
time_coord: Index of time coordinate.
Returns: Matrix of shape (n, m) of all pairwise Minkowski distances.
"""
space_coords = space_coords or [0, 1, 2]
assert x.dim() == 2, "x must be 2-dimensional"
assert y.dim() == 2, "y must be 2-dimensional"
dist = x[:, None] - y[None, :]
pos = dist[:, :, space_coords]
time = dist[:, :, time_coord] * c
return (pos**2).sum(dim=-1) - time**2
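
As a quick sanity check of the metric implemented above (a standalone sketch, not part of the diff): with c = 1, two points separated by one unit in space and one unit in time are light-like, so their Minkowski distance is 1 - 1 = 0.

import torch
from graphnet.models.graphs.edges.minkowski import compute_minkowski_distance_mat

p = torch.tensor([[0.0, 0.0, 0.0, 0.0]])  # (x, y, z, t)
q = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
print(compute_minkowski_distance_mat(p, q, c=1.0))  # tensor([[0.]])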


class MinkowskiKNNEdges(EdgeDefinition):
"""Builds edges between most light-like separated."""

def __init__(
self,
nb_nearest_neighbours: int,
c: float,
time_like_weight: float = 1.0,
space_coords: Optional[List[int]] = None,
time_coord: Optional[int] = 3,
):
"""Initialize MinkowskiKNNEdges.
Args:
nb_nearest_neighbours: Number of neighbours to connect to.
c: Speed of light, in scaled units.
time_like_weight: Preference for time-like over space-like edges.
Scales time_like distances by this value, before finding
nearest neighbours.
space_coords: Coordinates of x, y, z.
time_coord: Coordinate of time.
"""
super().__init__(name=__name__, class_name=self.__class__.__name__)
self.nb_nearest_neighbours = nb_nearest_neighbours
self.c = c
self.time_like_weight = time_like_weight
self.space_coords = space_coords or [0, 1, 2]
self.time_coord = time_coord

def _construct_edges(self, graph: Data) -> Data:
x, mask = to_dense_batch(graph.x, graph.batch)
count = 0
row = []
col = []
for batch in range(x.shape[0]):
distance_mat = compute_minkowski_distance_mat(
x_masked := x[batch][mask[batch]],
x_masked,
self.c,
self.space_coords,
self.time_coord,
)
num_points = x_masked.shape[0]
num_edges = min(self.nb_nearest_neighbours, num_points)
col += [
c
for c in range(num_points)
for _ in range(count, count + num_edges)
]
distance_mat[distance_mat < 0] *= -self.time_like_weight
distance_mat += (
torch.eye(distance_mat.shape[0]) * 1e9
) # self-loops
distance_sorted = distance_mat.argsort(dim=1)
distance_sorted += count # offset by previous events
row += distance_sorted[:num_edges].flatten().tolist()
count += num_points

graph.edge_index = torch.tensor(
[row, col], dtype=torch.long, device=graph.x.device
)
return graph
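
A minimal usage sketch of the new edge definition, mirroring the pattern in the test added later in this commit (the node features here are random placeholders in (x, y, z, t) order):

import torch
from torch_geometric.data import Data
from graphnet.models.graphs.edges import MinkowskiKNNEdges

edge_definition = MinkowskiKNNEdges(nb_nearest_neighbours=2, c=1.0)
graph = Data(x=torch.rand(10, 4))  # 10 pulses with (x, y, z, t) features
graph = edge_definition(graph)
print(graph.edge_index.shape)  # expected: torch.Size([2, 20])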
Binary file not shown.
3 changes: 2 additions & 1 deletion src/graphnet/models/standard_model.py
@@ -270,9 +270,10 @@ def training_step(
batch_size=self._get_batch_size(train_batch),
prog_bar=True,
on_epoch=True,
-on_step=True,
+on_step=False,
sync_dist=True,
)

current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
self.log("lr", current_lr, prog_bar=True, on_step=True)
return loss
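
For reference, a standalone sketch in plain PyTorch (not graphnet code) of where the logged value comes from: an optimizer keeps its current learning rate in its parameter groups, and a scheduler updates it in place, so reading param_groups[0]["lr"] each step reflects any schedule.

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

print(optimizer.param_groups[0]["lr"])  # 0.001
optimizer.step()
scheduler.step()
print(optimizer.param_groups[0]["lr"])  # 0.0001 (up to float rounding)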
29 changes: 18 additions & 11 deletions tests/deployment/queso_test.py
@@ -4,6 +4,7 @@
from os.path import join
from typing import TYPE_CHECKING, List, Sequence, Dict, Tuple, Any
import os
import numpy as np
import pytest

from graphnet.data.constants import FEATURES
@@ -139,16 +140,16 @@ def extract_predictions(
Returns:
Predictions from each model for each frame.
"""
-file = dataio.I3File(file)
+open_file = dataio.I3File(file)
data = []
-while file.more(): # type: ignore
-frame = file.pop_physics() # type: ignore
+while open_file.more(): # type: ignore
+frame = open_file.pop_physics() # type: ignore
predictions = {}
for frame_entry in frame.keys():
for model_path in model_paths:
model = model_path.split("/")[-1]
if model in frame_entry:
-predictions[model] = frame[frame_entry].value
+predictions[frame_entry] = frame[frame_entry].value
data.append(predictions)
return data

@@ -193,9 +194,7 @@ def test_deployment() -> None:
def verify_QUESO_integrity() -> None:
"""Test new and original i3 files contain same predictions."""
base_path = f"{PRETRAINED_MODEL_DIR}/icecube/upgrade/QUESO/"
-queso_original_file = glob(
-f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/*.i3.gz"
-)[0]
+queso_original_file = glob(f"{TEST_DATA_DIR}/deployment/QUESO/*.i3.gz")[0]
queso_new_file = glob(f"{TEST_DATA_DIR}/output/QUESO_test/*.i3.gz")[0]
queso_models = glob(base_path + "/*")

@@ -210,10 +209,18 @@ def verify_QUESO_integrity() -> None:
for frame in range(len(original_predictions)):
for model in original_predictions[frame].keys():
assert model in new_predictions[frame].keys()
-assert (
-new_predictions[frame][model]
-== original_predictions[frame][model]
-)
+try:
+assert np.isclose(
+new_predictions[frame][model],
+original_predictions[frame][model],
+equal_nan=True,
+)
+except AssertionError as e:
+print(
+f"Mismatch found in {model}: {new_predictions[frame][model]} vs. {original_predictions[frame][model]}"
+)
+raise e

return
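
A short standalone illustration of why the comparison above uses np.isclose with equal_nan=True: empty events are padded with NaN by the deployment module changed earlier in this commit, and NaN never compares equal to itself under plain equality.

import numpy as np

print(np.nan == np.nan)                            # False
print(np.isclose(np.nan, np.nan, equal_nan=True))  # True
print(np.isclose(1.000000001, 1.0))                # True (within default tolerance)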


160 changes: 160 additions & 0 deletions tests/models/test_minkowski.py
@@ -0,0 +1,160 @@
"""Unit tests for minkowski based edges."""
import pytest
import torch
from torch_geometric.data.data import Data

from graphnet.models.graphs.edges import KNNEdges, MinkowskiKNNEdges
from graphnet.models.graphs.edges.minkowski import (
compute_minkowski_distance_mat,
)


def test_compute_minkowski_distance_mat() -> None:
"""Testing the computation of the Minkowski distance matrix."""
vec1 = torch.tensor(
[
[
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
1.0,
1.0,
],
[
1.0,
0.0,
0.0,
1.0,
],
[
1.0,
0.0,
1.0,
2.0,
],
]
)
vec2 = torch.tensor(
[
[
0.0,
0.0,
0.0,
-1.0,
],
[
1.0,
1.0,
1.0,
0.0,
],
]
)
expected11 = torch.tensor(
[
[
0.0,
0.0,
0.0,
-2.0,
],
[
0.0,
0.0,
2.0,
0.0,
],
[
0.0,
2.0,
0.0,
0.0,
],
[
-2.0,
0.0,
0.0,
0.0,
],
]
)
expected12 = torch.tensor(
[[-1.0, 3.0], [-3.0, 1.0], [-3.0, 1.0], [-7.0, -3.0]]
)
expected22 = torch.tensor(
[
[0.0, 2.0],
[2.0, 0.0],
]
)
mat11 = compute_minkowski_distance_mat(vec1, vec1, c=1.0)
mat12 = compute_minkowski_distance_mat(vec1, vec2, c=1.0)
mat22 = compute_minkowski_distance_mat(vec2, vec2, c=1.0)

assert torch.allclose(mat11, expected11)
assert torch.allclose(mat12, expected12)
assert torch.allclose(mat22, expected22)


def test_minkowski_knn_edges() -> None:
"""Testing the minkowski knn edge definition."""
data = Data(
x=torch.tensor(
[
[
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
1.0,
1.0,
],
[
1.0,
0.0,
0.0,
1.0,
],
[
1.0,
0.0,
1.0,
2.0,
],
]
)
)
edge_index = MinkowskiKNNEdges(
nb_nearest_neighbours=2,
c=1.0,
)(data).edge_index
expected = torch.tensor(
[
[1, 2, 0, 3, 0, 3, 1, 2],
[0, 0, 1, 1, 2, 2, 3, 3],
]
)
assert torch.allclose(edge_index[1], expected[1])

# Allow for "permutation of connections" in edge_index[1]
assert torch.allclose(
edge_index[0, [0, 1]], expected[0, [0, 1]]
) or torch.allclose(edge_index[1, [0, 1]], expected[1, [1, 0]])
assert torch.allclose(
edge_index[0, [2, 3]], expected[0, [2, 3]]
) or torch.allclose(edge_index[1, [2, 3]], expected[1, [3, 2]])
assert torch.allclose(
edge_index[0, [4, 5]], expected[0, [4, 5]]
) or torch.allclose(edge_index[1, [4, 5]], expected[1, [5, 4]])
assert torch.allclose(
edge_index[0, [6, 7]], expected[0, [6, 7]]
) or torch.allclose(edge_index[1, [6, 7]], expected[1, [7, 6]])
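
To make the expected tensor above easier to read (a standalone sketch, not part of the test): edge_index[1] repeats each node index, and edge_index[0] lists the two most light-like neighbours chosen for that node.

import torch

expected = torch.tensor(
    [
        [1, 2, 0, 3, 0, 3, 1, 2],
        [0, 0, 1, 1, 2, 2, 3, 3],
    ]
)
for node in range(4):
    neighbours = expected[0][expected[1] == node].tolist()
    print(node, neighbours)
# 0 [1, 2]
# 1 [0, 3]
# 2 [0, 3]
# 3 [1, 2]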
