
Commit

Improvements to GRIT arguments, added new position encodings, fixed pickling
pweigel committed Dec 29, 2024
1 parent 48d5e85 commit 68634cc
Showing 7 changed files with 315 additions and 31 deletions.
35 changes: 35 additions & 0 deletions src/graphnet/models/components/embedding.py
@@ -367,3 +367,38 @@ def forward(self, data: Data) -> Data:

data.edge_index, data.edge_attr = out_idx, out_val
return data


class RWSELinearNodeEncoder(LightningModule):
"""Random walk structural node encoding."""

def __init__(
self,
emb_dim: int,
out_dim: int,
use_bias: bool = False,
):
"""Construct `RWSELinearEdgeEncoder`.
Args:
emb_dim: Embedding dimension.
out_dim: Output dimension.
use_bias: Apply bias to linear layer.
"""
super().__init__()

self.emb_dim = emb_dim
self.out_dim = out_dim

self.encoder = nn.Linear(emb_dim, out_dim, bias=use_bias)

def forward(self, data: Data) -> Data:
"""Forward pass."""
rwse = data.rwse
x = data.x

rwse = self.encoder(rwse)

data.x = torch.cat((x, rwse), dim=1)

return data
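
For orientation, here is a minimal usage sketch of the encoder above; the tensor sizes and the construction of the per-node `rwse` attribute on the `Data` object are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch: project per-node RWSE features and concatenate
# them onto the existing node features (sizes are assumed for illustration).
import torch
from torch_geometric.data import Data

from graphnet.models.components.embedding import RWSELinearNodeEncoder

num_nodes, feat_dim, rwse_dim, out_dim = 5, 16, 20, 8
data = Data(
    x=torch.randn(num_nodes, feat_dim),    # existing node features
    rwse=torch.rand(num_nodes, rwse_dim),  # random walk structural encoding
)

encoder = RWSELinearNodeEncoder(emb_dim=rwse_dim, out_dim=out_dim)
data = encoder(data)
assert data.x.shape == (num_nodes, feat_dim + out_dim)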
25 changes: 20 additions & 5 deletions src/graphnet/models/components/layers.py
@@ -6,7 +6,11 @@
import torch.nn as nn
from torch.functional import Tensor
from torch_geometric.nn import EdgeConv
from torch_geometric.nn.pool import knn_graph, global_add_pool
from torch_geometric.nn.pool import (
knn_graph,
global_mean_pool,
global_add_pool,
)
from torch_geometric.typing import Adj, PairTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
@@ -893,14 +897,13 @@ def forward(self, data: Data) -> Data:
x = self.fc1_x(x)
if e_attn_out is not None:
e = e_attn_out.flatten(1)
# TODO: Make this a nn.Dropout in initialization -PW
e = self.dropout2(e)
e = self.fc1_e(e)

if self.residual:
if self.rezero:
x = x * self.alpha1_x
x = x_attn_residual + x # residual connection
x = x_attn_residual + x

if e is not None:
if self.rezero:
@@ -946,34 +949,46 @@ class SANGraphHead(LightningModule):
def __init__(
self,
dim_in: int,
dim_out: int = 1,
L: int = 2,
activation: nn.Module = nn.ReLU,
pooling: str = "mean",
):
"""Construct `SANGraphHead`.
Args:
dim_in: Input dimension.
dim_out: Output dimension.
L: Number of hidden layers.
activation: Activation function.
pooling: Pooling method.
"""
super().__init__()
self.pooling_fun = global_add_pool
if pooling == "mean":
self.pooling_fun = global_mean_pool
elif pooling == "add":
self.pooling_fun = global_add_pool
else:
raise RuntimeError("Currently supports only 'add' or 'mean'.")

fc_layers = [
nn.Linear(dim_in // 2**n, dim_in // 2 ** (n + 1), bias=True)
for n in range(L)
]
assert dim_in // 2**L >= dim_out, "Too much dim reduction!"
fc_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True))
self.fc_layers = nn.ModuleList(fc_layers)
self.L = L
self.activation = activation()
self.dim_out = dim_in // 2**L
self.dim_out = dim_out

def forward(self, data: Data) -> Tensor:
"""Forward Pass."""
graph_emb = self.pooling_fun(data.x, data.batch)
for i in range(self.L):
graph_emb = self.fc_layers[i](graph_emb)
graph_emb = self.activation(graph_emb)
graph_emb = self.fc_layers[self.L](graph_emb)
# Original code applied a final linear layer to project to dim_out,
# but we will let the Task layer do that.
return graph_emb
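
To make the dimension bookkeeping in `SANGraphHead` concrete, a small sketch follows; the sizes and batch construction are assumptions for illustration. With `dim_in=128` and `L=2` the hidden layers halve the width twice (128 -> 64 -> 32), and the appended final layer maps to `dim_out`, which is why the assert requires `dim_in // 2**L >= dim_out`.

# Illustrative sketch (assumed sizes): graph pooling followed by the
# halving fc stack and the final projection to dim_out.
import torch
from torch_geometric.data import Batch, Data

from graphnet.models.components.layers import SANGraphHead

head = SANGraphHead(dim_in=128, dim_out=3, L=2, pooling="mean")

graph = Data(x=torch.randn(10, 128))  # 10 nodes with 128-dim embeddings
batch = Batch.from_data_list([graph, graph])

out = head(batch)  # pooled per graph, then 128 -> 64 -> 32 -> 3
assert out.shape == (2, 3)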
60 changes: 38 additions & 22 deletions src/graphnet/models/gnn/grit.py
@@ -9,7 +9,6 @@
"""

import torch.nn as nn

from torch import Tensor
from torch_geometric.data import Data

@@ -24,6 +23,7 @@
RRWPLinearNodeEncoder,
LinearNodeEncoder,
LinearEdgeEncoder,
RWSELinearNodeEncoder,
)


@@ -38,6 +38,7 @@ def __init__(
self,
nb_inputs: int,
hidden_dim: int,
nb_outputs: int = 1,
ksteps: int = 21,
n_layers: int = 10,
n_heads: int = 8,
@@ -56,13 +57,15 @@ def __init__(
enable_edge_transform: bool = True,
pred_head_layers: int = 2,
pred_head_activation: nn.Module = nn.ReLU,
pred_head_pooling: str = "mean",
position_encoding: str = "NoPE",
):
"""Construct `GRIT` model.
Args:
nb_inputs: Number of inputs.
hidden_dim: Size of hidden dimension.
dim_out: Size of output dimension.
nb_outputs: Size of output dimension.
ksteps: Number of random walk steps.
n_layers: Number of GRIT layers.
n_heads: Number of heads in MHA.
@@ -82,20 +85,36 @@
enable_edge_transform: Apply transformation to edges.
pred_head_layers: Number of layers in the prediction head.
pred_head_activation: Prediction head activation function.
pred_head_pooling: Pooling function to use for the prediction head,
either "mean" (default) or "add".
position_encoding: Method of position encoding ('NoPE', 'RRWP', or 'RWSE').
"""
super().__init__(nb_inputs, hidden_dim // 2**pred_head_layers)

self.node_encoder = LinearNodeEncoder(nb_inputs, hidden_dim)
self.edge_encoder = LinearEdgeEncoder(hidden_dim)

self.rrwp_abs_encoder = RRWPLinearNodeEncoder(ksteps, hidden_dim)
self.rrwp_rel_encoder = RRWPLinearEdgeEncoder(
ksteps,
hidden_dim,
pad_to_full_graph=pad_to_full_graph,
add_node_attr_as_self_loop=add_node_attr_as_self_loop,
fill_value=fill_value,
)
super().__init__(nb_inputs, nb_outputs)
self.position_encoding = position_encoding.lower()
if self.position_encoding == "nope":
encoders = [
LinearNodeEncoder(nb_inputs, hidden_dim),
LinearEdgeEncoder(hidden_dim),
]
elif self.position_encoding == "rrwp":
encoders = [
LinearNodeEncoder(nb_inputs, hidden_dim),
LinearEdgeEncoder(hidden_dim),
RRWPLinearNodeEncoder(ksteps, hidden_dim),
RRWPLinearEdgeEncoder(
ksteps,
hidden_dim,
pad_to_full_graph=pad_to_full_graph,
add_node_attr_as_self_loop=add_node_attr_as_self_loop,
fill_value=fill_value,
),
]
elif self.position_encoding == "rwse":
encoders = [
LinearNodeEncoder(nb_inputs, hidden_dim - (ksteps - 1)),
RWSELinearNodeEncoder(ksteps - 1, hidden_dim),
]
self.encoders = nn.ModuleList(encoders)

layers = []
for _ in range(n_layers):
@@ -120,19 +139,16 @@ def __init__(
self.layers = nn.ModuleList(layers)
self.head = SANGraphHead(
dim_in=hidden_dim,
dim_out=nb_outputs,
L=pred_head_layers,
activation=pred_head_activation,
pooling=pred_head_pooling,
)

def forward(self, x: Data) -> Tensor:
"""Forward pass."""
# Apply linear layers to node/edge features
x = self.node_encoder(x)
x = self.edge_encoder(x)

# Encode with RRWP
x = self.rrwp_abs_encoder(x)
x = self.rrwp_rel_encoder(x)
for encoder in self.encoders:
x = encoder(x)

# Apply GRIT layers
for layer in self.layers:
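
A construction-only sketch of how the new `position_encoding` argument selects an encoder stack; the argument values are illustrative assumptions, and the input graphs must be built with the matching graph definition (`KNNGraphNoPE`, `KNNGraphRRWP`, or `KNNGraphRWSE`) so that the `rrwp` / `rwse` attributes the encoders read are actually present.

# Illustrative construction only (assumed argument values).
from graphnet.models.gnn.grit import GRIT

model_nope = GRIT(nb_inputs=7, hidden_dim=64, nb_outputs=1, position_encoding="NoPE")
model_rrwp = GRIT(nb_inputs=7, hidden_dim=64, ksteps=21, position_encoding="RRWP")
model_rwse = GRIT(nb_inputs=7, hidden_dim=64, ksteps=21, position_encoding="RWSE")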
8 changes: 7 additions & 1 deletion src/graphnet/models/graphs/__init__.py
Expand Up @@ -6,4 +6,10 @@
"""

from .graph_definition import GraphDefinition
from .graphs import KNNGraph, EdgelessGraph, KNNGraphRRWP
from .graphs import (
KNNGraph,
EdgelessGraph,
KNNGraphRRWP,
KNNGraphRWSE,
KNNGraphNoPE,
)
12 changes: 12 additions & 0 deletions src/graphnet/models/graphs/edges/edges.py
@@ -6,6 +6,8 @@
import torch
from torch_geometric.nn import knn_graph, radius_graph
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.utils.num_nodes import maybe_num_nodes

from graphnet.models.utils import calculate_distance_matrix
from graphnet.models import Model
@@ -111,6 +113,16 @@ def __init__(
def _construct_edges(self, graph: Data) -> Data:
"""Define K-NN edges."""
graph = super()._construct_edges(graph)

if graph.edge_index.numel() == 0: # Check if edge_index is empty
num_nodes = graph.num_nodes
self_loops = torch.arange(num_nodes).repeat(2, 1)
graph.edge_index = self_loops

graph.num_nodes = maybe_num_nodes(graph.edge_index)
graph.edge_index = to_undirected(
graph.edge_index, num_nodes=graph.num_nodes
)
position_data = graph.x[:, self._columns]

src, tgt = graph.edge_index
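
The empty-edge fallback above can be seen in isolation with a tiny sketch (shapes assumed): if `knn_graph` returns no edges, e.g. for a graph with a single node, every node receives a self-loop so the later `graph.edge_index` indexing still works.

# Illustrative sketch (assumed shapes) of the self-loop fallback used above.
import torch

num_nodes = 3
edge_index = torch.empty(2, 0, dtype=torch.long)  # knn_graph found no edges

if edge_index.numel() == 0:
    edge_index = torch.arange(num_nodes).repeat(2, 1)

print(edge_index)
# tensor([[0, 1, 2],
#         [0, 1, 2]])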
