From 68634cc52e21f2a595b6c36af6fb5f756271554c Mon Sep 17 00:00:00 2001 From: Philip Weigel Date: Sun, 29 Dec 2024 13:25:32 -0500 Subject: [PATCH] Improvements to GRIT arguments, added new position encodings, fixed pickling --- src/graphnet/models/components/embedding.py | 35 +++++ src/graphnet/models/components/layers.py | 25 +++- src/graphnet/models/gnn/grit.py | 60 +++++---- src/graphnet/models/graphs/__init__.py | 8 +- src/graphnet/models/graphs/edges/edges.py | 12 ++ src/graphnet/models/graphs/graphs.py | 134 +++++++++++++++++++- src/graphnet/models/utils.py | 72 ++++++++++- 7 files changed, 315 insertions(+), 31 deletions(-) diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py index 3c02b6826..307fcb8c6 100644 --- a/src/graphnet/models/components/embedding.py +++ b/src/graphnet/models/components/embedding.py @@ -367,3 +367,38 @@ def forward(self, data: Data) -> Data: data.edge_index, data.edge_attr = out_idx, out_val return data + + +class RWSELinearNodeEncoder(LightningModule): + """Random walk structural node encoding.""" + + def __init__( + self, + emb_dim: int, + out_dim: int, + use_bias: bool = False, + ): + """Construct `RWSELinearEdgeEncoder`. + + Args: + emb_dim: Embedding dimension. + out_dim: Output dimension. + use_bias: Apply bias to linear layer. + """ + super().__init__() + + self.emb_dim = emb_dim + self.out_dim = out_dim + + self.encoder = nn.Linear(emb_dim, out_dim, bias=use_bias) + + def forward(self, data: Data) -> Data: + """Forward pass.""" + rwse = data.rwse + x = data.x + + rwse = self.encoder(rwse) + + data.x = torch.cat((x, rwse), dim=1) + + return data diff --git a/src/graphnet/models/components/layers.py b/src/graphnet/models/components/layers.py index 011128d52..33fcb94a5 100644 --- a/src/graphnet/models/components/layers.py +++ b/src/graphnet/models/components/layers.py @@ -6,7 +6,11 @@ import torch.nn as nn from torch.functional import Tensor from torch_geometric.nn import EdgeConv -from torch_geometric.nn.pool import knn_graph, global_add_pool +from torch_geometric.nn.pool import ( + knn_graph, + global_mean_pool, + global_add_pool, +) from torch_geometric.typing import Adj, PairTensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset @@ -893,14 +897,13 @@ def forward(self, data: Data) -> Data: x = self.fc1_x(x) if e_attn_out is not None: e = e_attn_out.flatten(1) - # TODO: Make this a nn.Dropout in initialization -PW e = self.dropout2(e) e = self.fc1_e(e) if self.residual: if self.rezero: x = x * self.alpha1_x - x = x_attn_residual + x # residual connection + x = x_attn_residual + x if e is not None: if self.rezero: @@ -946,27 +949,38 @@ class SANGraphHead(LightningModule): def __init__( self, dim_in: int, + dim_out: int = 1, L: int = 2, activation: nn.Module = nn.ReLU, + pooling: str = "mean", ): """Construct `SANGraphHead`. Args: dim_in: Input dimension. + dim_out: Output dimension. L: Number of hidden layers. activation: Activation function. + pooling: Pooling method. """ super().__init__() - self.pooling_fun = global_add_pool + if pooling == "mean": + self.pooling_fun = global_mean_pool + elif pooling == "add": + self.pooling_fun = global_add_pool + else: + raise RuntimeError("Currently supports only 'add' or 'mean'.") fc_layers = [ nn.Linear(dim_in // 2**n, dim_in // 2 ** (n + 1), bias=True) for n in range(L) ] + assert dim_in // 2**L >= dim_out, "Too much dim reduction!" + fc_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True)) self.fc_layers = nn.ModuleList(fc_layers) self.L = L self.activation = activation() - self.dim_out = dim_in // 2**L + self.dim_out = dim_out def forward(self, data: Data) -> Tensor: """Forward Pass.""" @@ -974,6 +988,7 @@ def forward(self, data: Data) -> Tensor: for i in range(self.L): graph_emb = self.fc_layers[i](graph_emb) graph_emb = self.activation(graph_emb) + graph_emb = self.fc_layers[self.L](graph_emb) # Original code applied a final linear layer to project to dim_out, # but we will let the Task layer do that. return graph_emb diff --git a/src/graphnet/models/gnn/grit.py b/src/graphnet/models/gnn/grit.py index a05dc0494..c5edf3d0b 100644 --- a/src/graphnet/models/gnn/grit.py +++ b/src/graphnet/models/gnn/grit.py @@ -9,7 +9,6 @@ """ import torch.nn as nn - from torch import Tensor from torch_geometric.data import Data @@ -24,6 +23,7 @@ RRWPLinearNodeEncoder, LinearNodeEncoder, LinearEdgeEncoder, + RWSELinearNodeEncoder, ) @@ -38,6 +38,7 @@ def __init__( self, nb_inputs: int, hidden_dim: int, + nb_outputs: int = 1, ksteps: int = 21, n_layers: int = 10, n_heads: int = 8, @@ -56,13 +57,15 @@ def __init__( enable_edge_transform: bool = True, pred_head_layers: int = 2, pred_head_activation: nn.Module = nn.ReLU, + pred_head_pooling: str = "mean", + position_encoding: str = "NoPE", ): """Construct `GRIT` model. Args: nb_inputs: Number of inputs. hidden_dim: Size of hidden dimension. - dim_out: Size of output dimension. + nb_outputs: Size of output dimension. ksteps: Number of random walk steps. n_layers: Number of GRIT layers. n_heads: Number of heads in MHA. @@ -82,20 +85,36 @@ def __init__( enable_edge_transform: Apply transformation to edges. pred_head_layers: Number of layers in the prediction head. pred_head_activation: Prediction head activation function. + pred_head_pooling: Pooling function to use for the prediction head, + either "mean" (default) or "add". + position_encoding: Method of position encoding. """ - super().__init__(nb_inputs, hidden_dim // 2**pred_head_layers) - - self.node_encoder = LinearNodeEncoder(nb_inputs, hidden_dim) - self.edge_encoder = LinearEdgeEncoder(hidden_dim) - - self.rrwp_abs_encoder = RRWPLinearNodeEncoder(ksteps, hidden_dim) - self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( - ksteps, - hidden_dim, - pad_to_full_graph=pad_to_full_graph, - add_node_attr_as_self_loop=add_node_attr_as_self_loop, - fill_value=fill_value, - ) + super().__init__(nb_inputs, nb_outputs) + self.position_encoding = position_encoding.lower() + if self.position_encoding == "nope": + encoders = [ + LinearNodeEncoder(nb_inputs, hidden_dim), + LinearEdgeEncoder(hidden_dim), + ] + elif self.position_encoding == "rrwp": + encoders = [ + LinearNodeEncoder(nb_inputs, hidden_dim), + LinearEdgeEncoder(hidden_dim), + RRWPLinearNodeEncoder(ksteps, hidden_dim), + RRWPLinearEdgeEncoder( + ksteps, + hidden_dim, + pad_to_full_graph=pad_to_full_graph, + add_node_attr_as_self_loop=add_node_attr_as_self_loop, + fill_value=fill_value, + ), + ] + elif self.position_encoding == "rwse": + encoders = [ + LinearNodeEncoder(nb_inputs, hidden_dim - (ksteps - 1)), + RWSELinearNodeEncoder(ksteps - 1, hidden_dim), + ] + self.encoders = nn.ModuleList(encoders) layers = [] for _ in range(n_layers): @@ -120,19 +139,16 @@ def __init__( self.layers = nn.ModuleList(layers) self.head = SANGraphHead( dim_in=hidden_dim, + dim_out=nb_outputs, L=pred_head_layers, activation=pred_head_activation, + pooling=pred_head_pooling, ) def forward(self, x: Data) -> Tensor: """Forward pass.""" - # Apply linear layers to node/edge features - x = self.node_encoder(x) - x = self.edge_encoder(x) - - # Encode with RRWP - x = self.rrwp_abs_encoder(x) - x = self.rrwp_rel_encoder(x) + for encoder in self.encoders: + x = encoder(x) # Apply GRIT layers for layer in self.layers: diff --git a/src/graphnet/models/graphs/__init__.py b/src/graphnet/models/graphs/__init__.py index 49c57d559..77387b607 100644 --- a/src/graphnet/models/graphs/__init__.py +++ b/src/graphnet/models/graphs/__init__.py @@ -6,4 +6,10 @@ """ from .graph_definition import GraphDefinition -from .graphs import KNNGraph, EdgelessGraph, KNNGraphRRWP +from .graphs import ( + KNNGraph, + EdgelessGraph, + KNNGraphRRWP, + KNNGraphRWSE, + KNNGraphNoPE, +) diff --git a/src/graphnet/models/graphs/edges/edges.py b/src/graphnet/models/graphs/edges/edges.py index d01e01d3a..2ce43fc68 100644 --- a/src/graphnet/models/graphs/edges/edges.py +++ b/src/graphnet/models/graphs/edges/edges.py @@ -6,6 +6,8 @@ import torch from torch_geometric.nn import knn_graph, radius_graph from torch_geometric.data import Data +from torch_geometric.utils import to_undirected +from torch_geometric.utils.num_nodes import maybe_num_nodes from graphnet.models.utils import calculate_distance_matrix from graphnet.models import Model @@ -111,6 +113,16 @@ def __init__( def _construct_edges(self, graph: Data) -> Data: """Define K-NN edges.""" graph = super()._construct_edges(graph) + + if graph.edge_index.numel() == 0: # Check if edge_index is empty + num_nodes = graph.num_nodes + self_loops = torch.arange(num_nodes).repeat(2, 1) + graph.edge_index = self_loops + + graph.num_nodes = maybe_num_nodes(graph.edge_index) + graph.edge_index = to_undirected( + graph.edge_index, num_nodes=graph.num_nodes + ) position_data = graph.x[:, self._columns] src, tgt = graph.edge_index diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index f946353e1..a2b393639 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -15,7 +15,7 @@ KNNDistanceEdges, ) from graphnet.models.graphs.nodes import NodeDefinition, NodesAsPulses -from graphnet.models.utils import add_full_rrwp +from graphnet.models.utils import add_full_rrwp, get_rw_landing_probs class KNNGraph(GraphDefinition): @@ -176,3 +176,135 @@ def forward( # type: ignore graph = add_full_rrwp(graph, walk_length=self.walk_length) return graph + + +class KNNGraphRWSE(GraphDefinition): + """KNN Graph with random walk structural encoding.""" + + def __init__( + self, + detector: Detector, + node_definition: Optional[NodeDefinition] = None, + edge_definition: Optional[EdgeDefinition] = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + nb_nearest_neighbours: int = 8, + columns: List[int] = [0, 1, 2], + walk_length: int = 8, + **kwargs: Any, + ) -> None: + """Construct k-nn graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + edge_definition: Definition of edges in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. + nb_nearest_neighbours: Number of edges for each node. + Defaults to 8. + columns: node feature columns used for distance calculation. + Defaults to [0, 1, 2]. + walk_length: number of steps for the random walk. + Defaults to 8. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=edge_definition + or KNNEdges( + nb_nearest_neighbours=nb_nearest_neighbours, + columns=columns, + ), + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + **kwargs, + ) + self.walk_length = walk_length + + def forward( # type: ignore + self, + input_features: np.ndarray, + input_feature_names: List[str], + **kwargs, + ) -> Data: + """Forward pass.""" + graph = super().forward(input_features, input_feature_names, **kwargs) + ksteps = torch.arange(1, self.walk_length) + graph.rwse = get_rw_landing_probs( + ksteps=ksteps, edge_index=graph.edge_index, edge_weight=None + ) + return graph + + +class KNNGraphNoPE(GraphDefinition): + """KNN Graph with edge distances and no positional encoding.""" + + def __init__( + self, + detector: Detector, + node_definition: Optional[NodeDefinition] = None, + edge_definition: Optional[EdgeDefinition] = None, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + nb_nearest_neighbours: int = 8, + columns: List[int] = [0, 1, 2], + **kwargs: Any, + ) -> None: + """Construct k-nn graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + edge_definition: Definition of edges in the graph. + input_feature_names: Name of input feature columns. + dtype: data type for node features. + perturbation_dict: Dictionary mapping a feature name to a standard + deviation according to which the values for this + feature should be randomly perturbed. Defaults + to None. + seed: seed or Generator used to randomly sample perturbations. + Defaults to None. + nb_nearest_neighbours: Number of edges for each node. + Defaults to 8. + columns: node feature columns used for distance calculation. + Defaults to [0, 1, 2]. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition or NodesAsPulses(), + edge_definition=edge_definition + or KNNDistanceEdges( + nb_nearest_neighbours=nb_nearest_neighbours, + columns=columns, + ), + dtype=dtype, + input_feature_names=input_feature_names, + perturbation_dict=perturbation_dict, + seed=seed, + **kwargs, + ) + + def forward( # type: ignore + self, + input_features: np.ndarray, + input_feature_names: List[str], + **kwargs, + ) -> Data: + """Forward pass.""" + graph = super().forward(input_features, input_feature_names, **kwargs) + return graph diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py index d70eb0687..4da1d048c 100644 --- a/src/graphnet/models/utils.py +++ b/src/graphnet/models/utils.py @@ -7,9 +7,10 @@ from torch_geometric.nn import knn_graph from torch_geometric.data import Batch, Data -from torch_geometric.utils import homophily, degree +from torch_geometric.utils import homophily, degree, to_dense_adj +from torch_geometric.utils.num_nodes import maybe_num_nodes -from torch_scatter import scatter +from torch_scatter import scatter, scatter_add from torch_sparse import SparseTensor @@ -266,3 +267,70 @@ def get_log_deg(data: Data) -> Tensor: log_deg = torch.log(deg + 1) log_deg = log_deg.view(data.num_nodes, 1) return log_deg + + +def get_rw_landing_probs( + ksteps: List, + edge_index: Tensor, + edge_weight: Tensor = None, + num_nodes: Optional[int] = None, + space_dim: int = 0, +) -> Tensor: + """Compute Random Walk landing probabilities for given list of K steps. + + Original code: + https://github.com/ETH-DISCO/Benchmarking-PEs + Args: + ksteps: List of k-steps for which to compute the RW landings + edge_index: PyG sparse representation of the graph + edge_weight: (optional) Edge weights + num_nodes: (optional) Number of nodes in the graph + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the random-walk diagonal by a factor `k^(space_dim/2)`. + In euclidean space, this correction means that the height of + the gaussian distribution stays almost constant across the number + of steps, if `space_dim` is the dimension of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + print(edge_index.shape) + if edge_weight is None: + edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) + num_nodes = maybe_num_nodes(edge_index, num_nodes) + source = edge_index[0] + # dest = edge_index[1] + + # Out degrees + deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) + deg_inv = deg.pow(-1.0) + deg_inv.masked_fill_(deg_inv == float("inf"), 0) + + if edge_index.numel() == 0: + P = edge_index.new_zeros((1, num_nodes, num_nodes)) + else: + # P = D^-1 * A + # 1 x (Num nodes) x (Num nodes) + P = torch.diag(deg_inv) @ to_dense_adj( + edge_index, max_num_nodes=num_nodes + ) + rws = [] + if ksteps == list(range(min(ksteps), max(ksteps) + 1)): + # Efficient way if ksteps are a consecutive sequence + Pk = P.clone().detach().matrix_power(min(ksteps)) + for k in range(min(ksteps), max(ksteps) + 1): + rws.append( + torch.diagonal(Pk, dim1=-2, dim2=-1) * (k ** (space_dim / 2)) + ) + Pk = Pk @ P + else: + # Explicitly raising P to power k for each k \in ksteps. + for k in ksteps: + rws.append( + torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) + * (k ** (space_dim / 2)) + ) + + # (Num nodes) x (K steps) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) + return rw_landing