Skip to content

Commit

Permalink
Merge pull request #630 from AMHermansen/add-minkowski-knn
Browse files Browse the repository at this point in the history
Implemented MinkowskiKNNEdges
  • Loading branch information
AMHermansen authored Jan 12, 2024
2 parents 62c5340 + 34bc2e0 commit f8d88b8
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/graphnet/models/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
and their features.
"""
from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges
from .minkowski import MinkowskiKNNEdges
98 changes: 98 additions & 0 deletions src/graphnet/models/graphs/edges/minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Module containing EdgeDefinitions based on the Minkowski Metric."""
from typing import Optional, List

import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_batch
from graphnet.models.graphs.edges.edges import EdgeDefinition


def compute_minkowski_distance_mat(
x: torch.Tensor,
y: torch.Tensor,
c: float,
space_coords: Optional[List[int]] = None,
time_coord: Optional[int] = 3,
) -> torch.Tensor:
"""Compute all pairwise Minkowski distances.
Args:
x: First tensor of shape (n, d).
y: Second tensor of shape (m, d).
c: Speed of light, in scaled units.
space_coords: Indices of space coordinates.
time_coord: Index of time coordinate.
Returns: Matrix of shape (n, m) of all pairwise Minkowski distances.
"""
space_coords = space_coords or [0, 1, 2]
assert x.dim() == 2, "x must be 2-dimensional"
assert y.dim() == 2, "x must be 2-dimensional"
dist = x[:, None] - y[None, :]
pos = dist[:, :, space_coords]
time = dist[:, :, time_coord] * c
return (pos**2).sum(dim=-1) - time**2


class MinkowskiKNNEdges(EdgeDefinition):
"""Builds edges between most light-like separated."""

def __init__(
self,
nb_nearest_neighbours: int,
c: float,
time_like_weight: float = 1.0,
space_coords: Optional[List[int]] = None,
time_coord: Optional[int] = 3,
):
"""Initialize MinkowskiKNNEdges.
Args:
nb_nearest_neighbours: Number of neighbours to connect to.
c: Speed of light, in scaled units.
time_like_weight: Preference to time-like over space-like edges.
Scales time_like distances by this value, before finding
nearest neighbours.
space_coords: Coordinates of x, y, z.
time_coord: Coordinate of time.
"""
super().__init__(name=__name__, class_name=self.__class__.__name__)
self.nb_nearest_neighbours = nb_nearest_neighbours
self.c = c
self.time_like_weight = time_like_weight
self.space_coords = space_coords or [0, 1, 2]
self.time_coord = time_coord

def _construct_edges(self, graph: Data) -> Data:
x, mask = to_dense_batch(graph.x, graph.batch)
count = 0
row = []
col = []
for batch in range(x.shape[0]):
distance_mat = compute_minkowski_distance_mat(
x_masked := x[batch][mask[batch]],
x_masked,
self.c,
self.space_coords,
self.time_coord,
)
num_points = x_masked.shape[0]
num_edges = min(self.nb_nearest_neighbours, num_points)
col += [
c
for c in range(num_points)
for _ in range(count, count + num_edges)
]
distance_mat[distance_mat < 0] *= -self.time_like_weight
distance_mat += (
torch.eye(distance_mat.shape[0]) * 1e9
) # self-loops
distance_sorted = distance_mat.argsort(dim=1)
distance_sorted += count # offset by previous events
row += distance_sorted[:num_edges].flatten().tolist()
count += num_points

graph.edge_index = torch.tensor(
[row, col], dtype=torch.long, device=graph.x.device
)
return graph
160 changes: 160 additions & 0 deletions tests/models/test_minkowski.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Unit tests for minkowski based edges."""
import pytest
import torch
from torch_geometric.data.data import Data

from graphnet.models.graphs.edges import KNNEdges, MinkowskiKNNEdges
from graphnet.models.graphs.edges.minkowski import (
compute_minkowski_distance_mat,
)


def test_compute_minkowski_distance_mat() -> None:
"""Testing the computation of the Minkowski distance matrix."""
vec1 = torch.tensor(
[
[
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
1.0,
1.0,
],
[
1.0,
0.0,
0.0,
1.0,
],
[
1.0,
0.0,
1.0,
2.0,
],
]
)
vec2 = torch.tensor(
[
[
0.0,
0.0,
0.0,
-1.0,
],
[
1.0,
1.0,
1.0,
0.0,
],
]
)
expected11 = torch.tensor(
[
[
0.0,
0.0,
0.0,
-2.0,
],
[
0.0,
0.0,
2.0,
0.0,
],
[
0.0,
2.0,
0.0,
0.0,
],
[
-2.0,
0.0,
0.0,
0.0,
],
]
)
expected12 = torch.tensor(
[[-1.0, 3.0], [-3.0, 1.0], [-3.0, 1.0], [-7.0, -3.0]]
)
expected22 = torch.tensor(
[
[0.0, 2.0],
[2.0, 0.0],
]
)
mat11 = compute_minkowski_distance_mat(vec1, vec1, c=1.0)
mat12 = compute_minkowski_distance_mat(vec1, vec2, c=1.0)
mat22 = compute_minkowski_distance_mat(vec2, vec2, c=1.0)

assert torch.allclose(mat11, expected11)
assert torch.allclose(mat12, expected12)
assert torch.allclose(mat22, expected22)


def test_minkowski_knn_edges() -> None:
"""Testing the minkowski knn edge definition."""
data = Data(
x=torch.tensor(
[
[
0.0,
0.0,
0.0,
0.0,
],
[
0.0,
0.0,
1.0,
1.0,
],
[
1.0,
0.0,
0.0,
1.0,
],
[
1.0,
0.0,
1.0,
2.0,
],
]
)
)
edge_index = MinkowskiKNNEdges(
nb_nearest_neighbours=2,
c=1.0,
)(data).edge_index
expected = torch.tensor(
[
[1, 2, 0, 3, 0, 3, 1, 2],
[0, 0, 1, 1, 2, 2, 3, 3],
]
)
assert torch.allclose(edge_index[1], expected[1])

# Allow for "permutation of connections" in edge_index[1]
assert torch.allclose(
edge_index[0, [0, 1]], expected[0, [0, 1]]
) or torch.allclose(edge_index[1, [0, 1]], expected[1, [1, 0]])
assert torch.allclose(
edge_index[0, [2, 3]], expected[0, [2, 3]]
) or torch.allclose(edge_index[1, [2, 3]], expected[1, [3, 2]])
assert torch.allclose(
edge_index[0, [4, 5]], expected[0, [4, 5]]
) or torch.allclose(edge_index[1, [4, 5]], expected[1, [5, 4]])
assert torch.allclose(
edge_index[0, [6, 7]], expected[0, [6, 7]]
) or torch.allclose(edge_index[1, [6, 7]], expected[1, [7, 6]])

0 comments on commit f8d88b8

Please sign in to comment.