Adds an improved implementation of DynEdge #596

Closed · wants to merge 2 commits
1 change: 1 addition & 0 deletions src/graphnet/models/gnn/__init__.py
@@ -4,3 +4,4 @@
from .dynedge import DynEdge
from .dynedge_jinst import DynEdgeJINST
from .dynedge_kaggle_tito import DynEdgeTITO
from .odynedge import ODynEdge
332 changes: 332 additions & 0 deletions src/graphnet/models/gnn/odynedge.py
@@ -0,0 +1,332 @@
"""Optimized implementation of the DynEdge GNN model architecture."""
from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor, LongTensor
from torch_geometric.data import Data
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

from graphnet.models.components.layers import DynEdgeConv
from graphnet.utilities.config import save_model_config
from graphnet.models.gnn.gnn import GNN
from graphnet.models.utils import calculate_xyzt_homophily

GLOBAL_POOLINGS = {
"min": scatter_min,
"max": scatter_max,
"sum": scatter_sum,
"mean": scatter_mean,
}
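# Note: each of these scatter functions reduces node rows per graph, e.g. for
# `x` of shape [num_nodes, d] and graph ids `batch`, `scatter_mean(x, batch,
# dim=0)` returns one d-dimensional row per graph.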


class ODynEdge(GNN):
"""DynEdge (dynamical edge convolutional) model."""

@save_model_config
def __init__(
self,
nb_inputs: int,
*,
nb_neighbours: int = 8,
features_subset: Optional[Union[List[int], slice]] = None,
dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
post_processing_layer_sizes: Optional[List[int]] = None,
readout_layer_sizes: Optional[List[int]] = None,
global_pooling_schemes: Optional[Union[str, List[str]]] = None,
add_global_variables_after_pooling: bool = False,
):
"""Construct `DynEdge`.

Args:
nb_inputs: Number of input features on each node.
nb_neighbours: Number of neighbours to use in the k-nearest
neighbour clustering which is performed after each (dynamical)
edge convolution.
features_subset: The subset of latent features on each node that
are used as metric dimensions when performing the k-nearest
neighbours clustering. Defaults to [0,1,2].
dynedge_layer_sizes: The layer sizes, or latent feature dimensions,
used in the `DynEdgeConv` layer. Each entry in
`dynedge_layer_sizes` corresponds to a single `DynEdgeConv`
layer; the integers in the corresponding tuple correspond to
the layer sizes in the multi-layer perceptron (MLP) that is
applied within each `DynEdgeConv` layer. That is, a list of
size-two tuples means that all `DynEdgeConv` layers contain a
two-layer MLP.
Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].
post_processing_layer_sizes: Hidden layer sizes in the MLP
following the skip-concatenation of the outputs of each
`DynEdgeConv` layer. Defaults to [336, 256].
readout_layer_sizes: Hidden layer sizes in the MLP following the
post-processing _and_ optional global pooling. As these are the
last layers in the model, the final read-out layer yields the
output of the `DynEdge` model. Defaults to [128,].
global_pooling_schemes: The list of global pooling schemes to use.
Options are: "min", "max", "mean", and "sum".
add_global_variables_after_pooling: Whether to add global variables
after global pooling. The alternative is to add (distribute)
them to the individual nodes before any convolutional
operations.
"""
# Latent feature subset for computing nearest neighbours in DynEdge.
if features_subset is None:
features_subset = slice(0, 3)

# DynEdge layer sizes
if dynedge_layer_sizes is None:
dynedge_layer_sizes = [
(128, 256),
(336, 256),
(336, 256),
(336, 256),
]

assert isinstance(dynedge_layer_sizes, list)
assert len(dynedge_layer_sizes)
assert all(isinstance(sizes, tuple) for sizes in dynedge_layer_sizes)
assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes)
assert all(
all(size > 0 for size in sizes) for sizes in dynedge_layer_sizes
)

self._dynedge_layer_sizes = dynedge_layer_sizes

# Post-processing layer sizes
if post_processing_layer_sizes is None:
post_processing_layer_sizes = [336, 256]

assert isinstance(post_processing_layer_sizes, list)
assert len(post_processing_layer_sizes)
assert all(size > 0 for size in post_processing_layer_sizes)

self._post_processing_layer_sizes = post_processing_layer_sizes

# Read-out layer sizes
if readout_layer_sizes is None:
readout_layer_sizes = [128]

assert isinstance(readout_layer_sizes, list)
assert len(readout_layer_sizes)
assert all(size > 0 for size in readout_layer_sizes)

self._readout_layer_sizes = readout_layer_sizes

# Global pooling scheme(s)
if isinstance(global_pooling_schemes, str):
global_pooling_schemes = [global_pooling_schemes]

if isinstance(global_pooling_schemes, list):
for pooling_scheme in global_pooling_schemes:
assert (
pooling_scheme in GLOBAL_POOLINGS
), f"Global pooling scheme {pooling_scheme} not supported."
else:
assert global_pooling_schemes is None

self._global_pooling_schemes = global_pooling_schemes

if add_global_variables_after_pooling:
assert self._global_pooling_schemes, (
"No global pooling schemes were request, so cannot add global"
" variables after pooling."
)
self._add_global_variables_after_pooling = (
add_global_variables_after_pooling
)

# Base class constructor
super().__init__(nb_inputs, self._readout_layer_sizes[-1])

# Remaining member variables
self._activation = torch.nn.LeakyReLU()
self._nb_inputs = nb_inputs
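# Four homophily scalars (x, y, z, t), log10(n_pulses), and the mean of
# each input feature; see `_calculate_global_variables`.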
self._nb_global_variables = 5 + nb_inputs
self._nb_neighbours = nb_neighbours
self._features_subset = features_subset

self._construct_layers()

def _construct_layers(self) -> None:
"""Construct layers (torch.nn.Modules)."""
# Convolutional operations
nb_input_features = self._nb_inputs
if not self._add_global_variables_after_pooling:
nb_input_features += self._nb_global_variables

self._conv_layers = torch.nn.ModuleList()
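# Linear projections of the raw node features (first entry, with bias)
# and of each `DynEdgeConv` output (bias-free, added in the loop below)
# onto the first post-processing width. Summing these in `forward` is
# mathematically identical to applying one `Linear` to the
# skip-concatenation of all layer outputs, as in the original DynEdge,
# but avoids materializing the concatenated tensor.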
self._first_post_processing = torch.nn.ModuleList(
[
torch.nn.Linear(
nb_input_features, self._post_processing_layer_sizes[0]
)
]
)
nb_latent_features = nb_input_features
for sizes in self._dynedge_layer_sizes:
layers = []
layer_sizes = [nb_latent_features] + list(sizes)
for ix, (nb_in, nb_out) in enumerate(
zip(layer_sizes[:-1], layer_sizes[1:])
):
if ix == 0:
nb_in *= 2
layers.append(torch.nn.Linear(nb_in, nb_out))
layers.append(self._activation)

conv_layer = DynEdgeConv(
torch.nn.Sequential(*layers),
aggr="add",
nb_neighbors=self._nb_neighbours,
features_subset=self._features_subset,
)
self._conv_layers.append(conv_layer)

nb_latent_features = nb_out
self._first_post_processing.append(
torch.nn.Linear(
nb_out, self._post_processing_layer_sizes[0], bias=False
)
)

# Post-processing operations
post_processing_layers = []
layer_sizes = list(self._post_processing_layer_sizes)

for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
post_processing_layers.append(self._activation)

self._post_processing = torch.nn.Sequential(*post_processing_layers)

# Read-out operations
nb_poolings = (
len(self._global_pooling_schemes)
if self._global_pooling_schemes
else 1
)
# Use the last post-processing layer size; one pooled copy per scheme.
nb_latent_features = layer_sizes[-1] * nb_poolings
if self._add_global_variables_after_pooling:
nb_latent_features += self._nb_global_variables

readout_layers = []
layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
readout_layers.append(torch.nn.Linear(nb_in, nb_out))
readout_layers.append(self._activation)

self._readout = torch.nn.Sequential(*readout_layers)

def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor:
"""Perform global pooling."""
assert self._global_pooling_schemes
pooled = []
for pooling_scheme in self._global_pooling_schemes:
pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
pooled_x = pooling_fn(x, index=batch, dim=0)
if isinstance(pooled_x, tuple) and len(pooled_x) == 2:
# `scatter_min`/`scatter_max` also return the arg-indices, unlike
# `scatter_mean`/`scatter_sum`; keep only the pooled values.
pooled_x, _ = pooled_x
pooled.append(pooled_x)

return torch.cat(pooled, dim=1)

def _calculate_global_variables(
self,
x: Tensor,
edge_index: LongTensor,
batch: LongTensor,
*additional_attributes: Tensor,
) -> Tensor:
"""Calculate global variables."""
# Calculate homophily (scalar variables)
h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)

# Calculate mean features
global_means = scatter_mean(x, batch, dim=0)

# Add global variables
global_variables = torch.cat(
[
global_means,
h_x,
h_y,
h_z,
h_t,
]
+ [attr.unsqueeze(dim=1) for attr in additional_attributes],
dim=1,
)

return global_variables

def forward(self, data: Data) -> Tensor:
"""Apply learnable forward pass."""
# Convenience variables
x, edge_index, batch = data.x, data.edge_index, data.batch

global_variables = self._calculate_global_variables(
x,
edge_index,
batch,
torch.log10(data.n_pulses),
)

# Distribute global variables out to each node
if not self._add_global_variables_after_pooling:
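# Build a one-hot (num_nodes, num_graphs) membership matrix and use it
# to copy each graph's global-variable row to all of its nodes. For
# standard contiguous PyG batch indices this equals
# `global_variables[batch]`.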
distribute = (
batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0)
).type(torch.float)

global_variables_distributed = torch.sum(
distribute.unsqueeze(dim=2)
* global_variables.unsqueeze(dim=0),
dim=1,
)

x = torch.cat((x, global_variables_distributed), dim=1)

# DynEdge-convolutions
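# Running sum of the per-layer projections (raw input first, then each
# convolution's output); see the note on `_first_post_processing` in
# `_construct_layers`.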
out = self._first_post_processing[0](x)
for conv_layer, linear_layer in zip(
self._conv_layers, self._first_post_processing[1:]
):
x, edge_index = conv_layer(x, edge_index, batch)
out += linear_layer(x)

# Post-processing
x = self._post_processing(self._activation(out))

# (Optional) Global pooling
if self._global_pooling_schemes:
x = self._global_pooling(x, batch=batch)
if self._add_global_variables_after_pooling:
x = torch.cat(
[
x,
global_variables,
],
dim=1,
)

# Read-out
x = self._readout(x)

return x
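For reference, a minimal usage sketch (not part of the diff). The feature count and pooling schemes below are illustrative, and `data` is assumed to be a `torch_geometric` batch carrying the `x`, `edge_index`, `batch`, and `n_pulses` attributes that `forward` reads:

from graphnet.models.gnn import ODynEdge

# Illustrative configuration; `nb_inputs=7` is an assumption, not a default.
model = ODynEdge(
    nb_inputs=7,
    global_pooling_schemes=["min", "max", "mean", "sum"],
)

# With global pooling enabled, the model returns one row per graph:
# output = model(data)  # Tensor of shape [num_graphs, 128]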