-
Notifications
You must be signed in to change notification settings - Fork 96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Particlenet added and all km3net dependencies are deleted #740
Merged
Merged
Changes from 2 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
"""Implementation of the ParticleNet GNN model architecture.""" | ||
from typing import List, Optional, Callable, Tuple, Union | ||
|
||
import torch | ||
from torch import Tensor, LongTensor | ||
from torch_geometric.data import Data | ||
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum | ||
|
||
from graphnet.models.components.layers import DynEdgeConv | ||
from graphnet.models.gnn.gnn import GNN | ||
|
||
GLOBAL_POOLINGS = { | ||
"min": scatter_min, | ||
"max": scatter_max, | ||
"sum": scatter_sum, | ||
"mean": scatter_mean, | ||
} | ||
|
||
|
||
class ParticleNeT(GNN): | ||
"""ParticleNeT (dynamical edge convolutional) model. | ||
|
||
Inspired by: https://arxiv.org/abs/1902.08570 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
nb_inputs: int, | ||
*, | ||
nb_neighbours: int = 16, | ||
features_subset: Optional[Union[List[int], slice]] = None, | ||
dynamic: bool = True, | ||
dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = [ | ||
(64, 64, 64), | ||
(128, 128, 128), | ||
(256, 256, 256), | ||
], | ||
readout_layer_sizes: Optional[List[int]] = [256], | ||
global_pooling_schemes: Optional[Union[str, List[str]]] = "mean", | ||
activation_layer: Optional[str] = "relu", | ||
add_batchnorm_layer: bool = True, | ||
dropout_readout: float = 0.1, | ||
skip_readout: bool = False, | ||
): | ||
"""Construct `ParticleNeT`. | ||
|
||
Args: | ||
nb_inputs: Number of input features on each node. | ||
nb_neighbours: Number of neighbours to used in the k-nearest | ||
neighbour clustering which is performed after each (dynamical) | ||
edge convolution. | ||
features_subset: The subset of latent features on each node that | ||
are used as metric dimensions when performing the k-nearest | ||
neighbours clustering. Defaults to [0,1,2]. | ||
dynamic: wether or not update the edges after every `DynEdgeConv` | ||
block. | ||
dynedge_layer_sizes: The layer sizes, or latent feature dimenions, | ||
used in the `DynEdgeConv` layer. Each entry in | ||
`dynedge_layer_sizes` corresponds to a single `DynEdgeConv` | ||
layer; the integers in the corresponding tuple corresponds to | ||
the layer sizes in the multi-layer perceptron (MLP) that is | ||
applied within each `DynEdgeConv` layer. That is, a list of | ||
size-three tuples means that all `DynEdgeConv` layers contain | ||
a three-layer MLP. | ||
Defaults to [(64, 64, 64), (128, 128, 128), (256, 256, 256)]. | ||
readout_layer_sizes: Hidden layer size in the MLP following the | ||
post-processing _and_ optional global pooling. As this is the | ||
last layer in the model, it yields the output of the `DynEdge` | ||
model. Defaults to [256,]. | ||
global_pooling_schemes: The list global pooling schemes to use. | ||
Options are: "min", "max", "mean", and "sum". | ||
Default to "mean". | ||
activation_layer: The activation function to use in the model. | ||
Default to "relu". | ||
add_batchnorm_layer: Whether to add a batch normalization layer after | ||
each linear layer. Default to True. | ||
dropout_readout: Dropout value to use in the readout layer(s). | ||
Default to 0.1. | ||
skip_readout: Whether to skip the readout layer(s). If `True`, the | ||
output of the last DynEdgeConv block is returned directly. | ||
""" | ||
# Latent feature subset for computing nearest neighbours in ParticleNeT. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line too long. See CodeClimate |
||
if features_subset is None: | ||
features_subset = slice(0, 3) | ||
|
||
# DynEdge layer sizes | ||
if dynedge_layer_sizes is None: | ||
dynedge_layer_sizes = [ | ||
(64, 64, 64), | ||
( | ||
128, | ||
128, | ||
128, | ||
), | ||
( | ||
256, | ||
256, | ||
256, | ||
), | ||
] | ||
|
||
dynedge_layer_sizes_check = [] | ||
for sizes in dynedge_layer_sizes: | ||
if isinstance(sizes, list): | ||
sizes = tuple(sizes) | ||
dynedge_layer_sizes_check.append(sizes) | ||
|
||
assert isinstance(dynedge_layer_sizes_check, list) | ||
assert len(dynedge_layer_sizes_check) | ||
assert all( | ||
isinstance(sizes, tuple) for sizes in dynedge_layer_sizes_check | ||
) | ||
assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes_check) | ||
assert all( | ||
all(size > 0 for size in sizes) | ||
for sizes in dynedge_layer_sizes_check | ||
) | ||
|
||
self._dynedge_layer_sizes = dynedge_layer_sizes_check | ||
|
||
# Read-out layer sizes | ||
if readout_layer_sizes is None: | ||
readout_layer_sizes = [ | ||
256, | ||
] | ||
|
||
assert isinstance(readout_layer_sizes, list) | ||
assert len(readout_layer_sizes) | ||
assert all(size > 0 for size in readout_layer_sizes) | ||
|
||
self._readout_layer_sizes = readout_layer_sizes | ||
|
||
# Global pooling scheme(s) | ||
if isinstance(global_pooling_schemes, str): | ||
global_pooling_schemes = [global_pooling_schemes] | ||
|
||
if isinstance(global_pooling_schemes, list): | ||
for pooling_scheme in global_pooling_schemes: | ||
assert ( | ||
pooling_scheme in GLOBAL_POOLINGS | ||
), f"Global pooling scheme {pooling_scheme} not supported." | ||
else: | ||
assert global_pooling_schemes is None | ||
|
||
self._global_pooling_schemes = global_pooling_schemes | ||
|
||
if activation_layer is None or activation_layer.lower() == "relu": | ||
activation_layer = torch.nn.ReLU() | ||
elif activation_layer.lower() == "gelu": | ||
activation_layer = torch.nn.GELU() | ||
else: | ||
raise ValueError( | ||
f"Activation layer {activation_layer} not supported." | ||
) | ||
|
||
# Base class constructor | ||
super().__init__(nb_inputs, self._readout_layer_sizes[-1]) | ||
|
||
# Remaining member variables() | ||
self._activation = activation_layer | ||
self._nb_inputs = nb_inputs | ||
self._nb_neighbours = nb_neighbours | ||
self._features_subset = features_subset | ||
self._dynamic = dynamic | ||
self._add_batchnorm_layer = add_batchnorm_layer | ||
self._dropout_readout = dropout_readout | ||
self._skip_readout = skip_readout | ||
|
||
self._construct_layers() | ||
|
||
# Builds the network | ||
def _construct_layers(self) -> None: | ||
"""Construct layers (torch.nn.Modules).""" | ||
# Convolutional operations | ||
nb_input_features = self._nb_inputs | ||
|
||
self._conv_layers = torch.nn.ModuleList() | ||
nb_latent_features = nb_input_features | ||
for sizes in self._dynedge_layer_sizes: | ||
layers = [] | ||
layer_sizes = [nb_latent_features] + list(sizes) | ||
for ix, (nb_in, nb_out) in enumerate( | ||
zip(layer_sizes[:-1], layer_sizes[1:]) | ||
): | ||
if ix == 0: | ||
nb_in *= 2 | ||
layers.append(torch.nn.Linear(nb_in, nb_out)) | ||
if self._add_batchnorm_layer: | ||
layers.append(torch.nn.BatchNorm1d(nb_out)) | ||
layers.append(self._activation) | ||
|
||
conv_layer = DynEdgeConv( | ||
torch.nn.Sequential(*layers), | ||
aggr="mean", | ||
nb_neighbors=self._nb_neighbours, | ||
features_subset=self._features_subset, | ||
) | ||
self._conv_layers.append(conv_layer) | ||
|
||
nb_latent_features = nb_out | ||
|
||
# Read-out operations | ||
nb_poolings = ( | ||
len(self._global_pooling_schemes) | ||
if self._global_pooling_schemes | ||
else 1 | ||
) | ||
nb_latent_features = nb_out * nb_poolings | ||
|
||
readout_layers = [] | ||
layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes) | ||
for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): | ||
readout_layers.append(torch.nn.Linear(nb_in, nb_out)) | ||
readout_layers.append(self._activation) | ||
readout_layers.append(torch.nn.Dropout(self._dropout_readout)) | ||
|
||
self._readout = torch.nn.Sequential(*readout_layers) | ||
|
||
def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor: | ||
"""Perform global pooling.""" | ||
assert self._global_pooling_schemes | ||
pooled = [] | ||
for pooling_scheme in self._global_pooling_schemes: | ||
pooling_fn = GLOBAL_POOLINGS[pooling_scheme] | ||
pooled_x = pooling_fn(x, index=batch, dim=0) | ||
if isinstance(pooled_x, tuple) and len(pooled_x) == 2: | ||
# `scatter_{min,max}`, which return also an argument, vs. | ||
# `scatter_{mean,sum}` | ||
pooled_x, _ = pooled_x | ||
pooled.append(pooled_x) | ||
|
||
return torch.cat(pooled, dim=1) | ||
|
||
def forward(self, data: Data) -> Tensor: | ||
"""Apply learnable forward pass.""" | ||
# Convenience variables | ||
x, edge_index, batch = data.x, data.edge_index, data.batch | ||
|
||
# DynEdge-convolutions | ||
for conv_layer in self._conv_layers: | ||
if self._dynamic: | ||
x, edge_index = conv_layer(x, edge_index, batch) | ||
else: | ||
x, _ = conv_layer(x, edge_index, batch) | ||
|
||
# Read-out | ||
if not self._skip_readout: | ||
# (Optional) Global pooling | ||
if self._global_pooling_schemes: | ||
x = self._global_pooling(x, batch=batch) | ||
|
||
# Read-out | ||
x = self._readout(x) | ||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ | |
from .dynedge_kaggle_tito import DynEdgeTITO | ||
from .RNN_tito import RNN_TITO | ||
from .icemix import DeepIce | ||
from .ParticleNeT import ParticleNeT | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line too long. See CodeClimate.