From 1cec63d46356458b8abbdf56f58f408718327900 Mon Sep 17 00:00:00 2001 From: amhermansen Date: Wed, 15 Nov 2023 17:31:42 +0100 Subject: [PATCH 1/7] Implemented MinkowskiKNNEdges --- src/graphnet/models/graphs/edges/__init__.py | 1 + src/graphnet/models/graphs/edges/minkowski.py | 100 ++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 src/graphnet/models/graphs/edges/minkowski.py diff --git a/src/graphnet/models/graphs/edges/__init__.py b/src/graphnet/models/graphs/edges/__init__.py index 7da8baa7c..40c8bbeab 100644 --- a/src/graphnet/models/graphs/edges/__init__.py +++ b/src/graphnet/models/graphs/edges/__init__.py @@ -5,3 +5,4 @@ and their features. """ from .edges import EdgeDefinition, KNNEdges, RadialEdges, EuclideanEdges +from .minkowski import MinkowskiKNNEdges diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py new file mode 100644 index 000000000..583b4a0e7 --- /dev/null +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -0,0 +1,100 @@ +"""Module containing EdgeDefinitions based on the Minkowski Metric.""" +from typing import Optional, List + +import torch +from torch_geometric.data import Data +from torch_geometric.utils import to_dense_batch +from graphnet.models.graphs.edges import EdgeDefinition + + +def compute_minkowski_distance_mat( + x: torch.Tensor, + y: torch.Tensor, + c: float, + space_coords: Optional[List[int]] = None, + time_coord: Optional[int] = 3, +) -> torch.Tensor: + """Compute all pairwise Minkowski distances. + + Args: + x: First tensor of shape (n, d). + y: Second tensor of shape (m, d). + c: Speed of light, in scaled units. + space_coords: Indices of space coordinates. + time_coord: Index of time coordinate. + + Returns: Matrix of shape (n, m) of all pairwise Minkowski distances. + """ + assert x.shape == y.shape, "x and y must have the same shape" + assert x.dim() == 2, "x and y must be 2-dimensional" + dist = x[:, None] - y[None, :] + pos = dist[:, :, space_coords] + time = dist[:, :, time_coord] * c + return (pos**2).sum(dim=-1) - time**2 + + +# TODO: Replace use of MinkowskiKNNEdges with +# custom Cuda/cpp kernel for reduced memory usage. +# Currently, O(n^2) memory is used, but O(n*k) is possible. +# Where n is the number of points in the largest event, +# and k is the number of neighbours to connect to. +class MinkowskiKNNEdges(EdgeDefinition): + """Builds edges between most light-like separated.""" + + def __init__( + self, + nb_nearest_neighbours: int, + c: float, + time_like_weight: float = 1.0, + space_coords: Optional[List[int]] = None, + time_coord: Optional[int] = 3, + ): + """Initialize MinkowskiKNNEdges. + + Args: + nb_nearest_neighbours: Number of neighbours to connect to. + c: Speed of light, in scaled units. + time_like_weight: Preference to time-like over space-like edges. + Scales time_like distances by this value, before finding + nearest neighbours. + space_coords: Coordinates of x, y, z. + time_coord: Coordinate of time. + """ + super().__init__(name=__name__, class_name=self.__class__.__name__) + self.nb_nearest_neighbours = nb_nearest_neighbours + self.c = c + self.time_like_weight = time_like_weight + self.space_coords = space_coords or [0, 1, 2] + self.time_coord = time_coord + + def _construct_edges(self, graph: Data) -> Data: + x, mask = to_dense_batch(graph.x, graph.batch) + count = 0 + row = [] + col = [] + for batch in range(x.shape[0]): + distance_mat = compute_minkowski_distance_mat( + x_masked := x[batch][mask[batch]], + x_masked, + self.c, + self.space_coords, + self.time_coord, + ) + num_points = x_masked.shape[0] + num_edges = min(self.nb_nearest_neighbours, num_points) + col += [ + c + for c in range(count, count + num_edges) + for _ in range(num_points) + ] + distance_mat[distance_mat < 0] *= self.time_like_weight + + distance_sorted = distance_mat.argsort(dim=1) + distance_sorted += count # offset by previous events + row += distance_sorted[:num_edges].flatten().tolist() + count += num_points + + graph.edge_index = torch.tensor( + [row, col], dtype=torch.long, device=graph.x.device + ) + return graph From 1ec436896efa7ed06d3015a2122d5e581bfd7709 Mon Sep 17 00:00:00 2001 From: Andreas Michael Hermansen <97125645+AMHermansen@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:45:18 +0100 Subject: [PATCH 2/7] Swapped row/col There was a bug where source/target was swapped. --- src/graphnet/models/graphs/edges/minkowski.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 583b4a0e7..8ecb6f197 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -82,7 +82,7 @@ def _construct_edges(self, graph: Data) -> Data: ) num_points = x_masked.shape[0] num_edges = min(self.nb_nearest_neighbours, num_points) - col += [ + row += [ c for c in range(count, count + num_edges) for _ in range(num_points) @@ -91,7 +91,7 @@ def _construct_edges(self, graph: Data) -> Data: distance_sorted = distance_mat.argsort(dim=1) distance_sorted += count # offset by previous events - row += distance_sorted[:num_edges].flatten().tolist() + col += distance_sorted[:num_edges].flatten().tolist() count += num_points graph.edge_index = torch.tensor( From 594adacbeb5cfb56d7d279d11703c1a448a58ece Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Tue, 12 Dec 2023 21:00:44 +0100 Subject: [PATCH 3/7] Bugfix --- src/graphnet/models/graphs/edges/minkowski.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 8ecb6f197..17f624636 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -4,7 +4,7 @@ import torch from torch_geometric.data import Data from torch_geometric.utils import to_dense_batch -from graphnet.models.graphs.edges import EdgeDefinition +from graphnet.models.graphs.edges.edges import EdgeDefinition def compute_minkowski_distance_mat( @@ -25,8 +25,9 @@ def compute_minkowski_distance_mat( Returns: Matrix of shape (n, m) of all pairwise Minkowski distances. """ - assert x.shape == y.shape, "x and y must have the same shape" - assert x.dim() == 2, "x and y must be 2-dimensional" + space_coords = space_coords or [0, 1, 2] + assert x.dim() == 2, "x must be 2-dimensional" + assert y.dim() == 2, "x must be 2-dimensional" dist = x[:, None] - y[None, :] pos = dist[:, :, space_coords] time = dist[:, :, time_coord] * c @@ -84,11 +85,13 @@ def _construct_edges(self, graph: Data) -> Data: num_edges = min(self.nb_nearest_neighbours, num_points) row += [ c - for c in range(count, count + num_edges) - for _ in range(num_points) + for c in range(num_points) + for _ in range(count, count + num_edges) ] - distance_mat[distance_mat < 0] *= self.time_like_weight - + distance_mat[distance_mat < 0] *= -self.time_like_weight + distance_mat += ( + torch.eye(distance_mat.shape[0]) * 1e9 + ) # self-loops distance_sorted = distance_mat.argsort(dim=1) distance_sorted += count # offset by previous events col += distance_sorted[:num_edges].flatten().tolist() From e7ffd6d440bc06cfef5194225a9c84f2bdb89c6b Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Tue, 12 Dec 2023 21:01:31 +0100 Subject: [PATCH 4/7] Removed optimization todo comment --- src/graphnet/models/graphs/edges/minkowski.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 17f624636..00fb440d7 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -34,11 +34,6 @@ def compute_minkowski_distance_mat( return (pos**2).sum(dim=-1) - time**2 -# TODO: Replace use of MinkowskiKNNEdges with -# custom Cuda/cpp kernel for reduced memory usage. -# Currently, O(n^2) memory is used, but O(n*k) is possible. -# Where n is the number of points in the largest event, -# and k is the number of neighbours to connect to. class MinkowskiKNNEdges(EdgeDefinition): """Builds edges between most light-like separated.""" From ca12930f7c314bcbdffecac386836be0699de577 Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Tue, 12 Dec 2023 21:02:31 +0100 Subject: [PATCH 5/7] Added unit tests for minkowski --- tests/models/test_minkowski.py | 160 +++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 tests/models/test_minkowski.py diff --git a/tests/models/test_minkowski.py b/tests/models/test_minkowski.py new file mode 100644 index 000000000..4a42ca25b --- /dev/null +++ b/tests/models/test_minkowski.py @@ -0,0 +1,160 @@ +"""Unit tests for minkowski based edges.""" +import pytest +import torch +from torch_geometric.data.data import Data + +from graphnet.models.graphs.edges import KNNEdges, MinkowskiKNNEdges +from graphnet.models.graphs.edges.minkowski import ( + compute_minkowski_distance_mat, +) + + +def test_compute_minkowski_distance_mat() -> None: + """Testing the computation of the Minkowski distance matrix.""" + vec1 = torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 1.0, + ], + [ + 1.0, + 0.0, + 0.0, + 1.0, + ], + [ + 1.0, + 0.0, + 1.0, + 2.0, + ], + ] + ) + vec2 = torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + -1.0, + ], + [ + 1.0, + 1.0, + 1.0, + 0.0, + ], + ] + ) + expected11 = torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + -2.0, + ], + [ + 0.0, + 0.0, + 2.0, + 0.0, + ], + [ + 0.0, + 2.0, + 0.0, + 0.0, + ], + [ + -2.0, + 0.0, + 0.0, + 0.0, + ], + ] + ) + expected12 = torch.tensor( + [[-1.0, 3.0], [-3.0, 1.0], [-3.0, 1.0], [-7.0, -3.0]] + ) + expected22 = torch.tensor( + [ + [0.0, 2.0], + [2.0, 0.0], + ] + ) + mat11 = compute_minkowski_distance_mat(vec1, vec1, c=1.0) + mat12 = compute_minkowski_distance_mat(vec1, vec2, c=1.0) + mat22 = compute_minkowski_distance_mat(vec2, vec2, c=1.0) + + assert torch.allclose(mat11, expected11) + assert torch.allclose(mat12, expected12) + assert torch.allclose(mat22, expected22) + + +def test_minkowski_knn_edges() -> None: + """Testing the minkowski knn edge definition.""" + data = Data( + x=torch.tensor( + [ + [ + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 1.0, + ], + [ + 1.0, + 0.0, + 0.0, + 1.0, + ], + [ + 1.0, + 0.0, + 1.0, + 2.0, + ], + ] + ) + ) + edge_index = MinkowskiKNNEdges( + nb_nearest_neighbours=2, + c=1.0, + )(data).edge_index + expected = torch.tensor( + [ + [0, 0, 1, 1, 2, 2, 3, 3], + [1, 2, 0, 3, 0, 3, 1, 2], + ] + ) + assert torch.allclose(edge_index[0], expected[0]) + + # Allow for "permutation of connections" in edge_index[1] + assert torch.allclose( + edge_index[1, [0, 1]], expected[1, [0, 1]] + ) or torch.allclose(edge_index[1, [0, 1]], expected[1, [1, 0]]) + assert torch.allclose( + edge_index[1, [2, 3]], expected[1, [2, 3]] + ) or torch.allclose(edge_index[1, [2, 3]], expected[1, [3, 2]]) + assert torch.allclose( + edge_index[1, [4, 5]], expected[1, [4, 5]] + ) or torch.allclose(edge_index[1, [4, 5]], expected[1, [5, 4]]) + assert torch.allclose( + edge_index[1, [6, 7]], expected[1, [6, 7]] + ) or torch.allclose(edge_index[1, [6, 7]], expected[1, [7, 6]]) From c5dbb1d69bea106dde478cda575d59353bbc36b9 Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Tue, 12 Dec 2023 21:05:38 +0100 Subject: [PATCH 6/7] Swapped row/col --- src/graphnet/models/graphs/edges/minkowski.py | 4 ++-- tests/models/test_minkowski.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 00fb440d7..5d1134ec5 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -78,7 +78,7 @@ def _construct_edges(self, graph: Data) -> Data: ) num_points = x_masked.shape[0] num_edges = min(self.nb_nearest_neighbours, num_points) - row += [ + col += [ c for c in range(num_points) for _ in range(count, count + num_edges) @@ -89,7 +89,7 @@ def _construct_edges(self, graph: Data) -> Data: ) # self-loops distance_sorted = distance_mat.argsort(dim=1) distance_sorted += count # offset by previous events - col += distance_sorted[:num_edges].flatten().tolist() + row += distance_sorted[:num_edges].flatten().tolist() count += num_points graph.edge_index = torch.tensor( diff --git a/tests/models/test_minkowski.py b/tests/models/test_minkowski.py index 4a42ca25b..761d4f993 100644 --- a/tests/models/test_minkowski.py +++ b/tests/models/test_minkowski.py @@ -139,8 +139,8 @@ def test_minkowski_knn_edges() -> None: )(data).edge_index expected = torch.tensor( [ - [0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 3, 0, 3, 1, 2], + [0, 0, 1, 1, 2, 2, 3, 3], ] ) assert torch.allclose(edge_index[0], expected[0]) From 34bc2e0eff92bd6db00d286cf8216b34b596cda3 Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Tue, 12 Dec 2023 21:15:13 +0100 Subject: [PATCH 7/7] Swapped row/col for assertions to match previous commit --- tests/models/test_minkowski.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/test_minkowski.py b/tests/models/test_minkowski.py index 761d4f993..af66196cf 100644 --- a/tests/models/test_minkowski.py +++ b/tests/models/test_minkowski.py @@ -143,18 +143,18 @@ def test_minkowski_knn_edges() -> None: [0, 0, 1, 1, 2, 2, 3, 3], ] ) - assert torch.allclose(edge_index[0], expected[0]) + assert torch.allclose(edge_index[1], expected[1]) # Allow for "permutation of connections" in edge_index[1] assert torch.allclose( - edge_index[1, [0, 1]], expected[1, [0, 1]] + edge_index[0, [0, 1]], expected[0, [0, 1]] ) or torch.allclose(edge_index[1, [0, 1]], expected[1, [1, 0]]) assert torch.allclose( - edge_index[1, [2, 3]], expected[1, [2, 3]] + edge_index[0, [2, 3]], expected[0, [2, 3]] ) or torch.allclose(edge_index[1, [2, 3]], expected[1, [3, 2]]) assert torch.allclose( - edge_index[1, [4, 5]], expected[1, [4, 5]] + edge_index[0, [4, 5]], expected[0, [4, 5]] ) or torch.allclose(edge_index[1, [4, 5]], expected[1, [5, 4]]) assert torch.allclose( - edge_index[1, [6, 7]], expected[1, [6, 7]] + edge_index[0, [6, 7]], expected[0, [6, 7]] ) or torch.allclose(edge_index[1, [6, 7]], expected[1, [7, 6]])