From d1b97d559a4a3b381f7378bf7d13811969f883cb Mon Sep 17 00:00:00 2001
From: ArturoLlorente <85907219+ArturoLlorente@users.noreply.github.com>
Date: Tue, 24 Oct 2023 16:37:02 +0200
Subject: [PATCH] Redefinition of Tito model (#611)

* changes to have a more general GNN definition. All models used in the Tito solution can be replicated with this class. Small changes in the training example to fit new input variables

* changes from PR #607 to solve config summarizing issue

* added features_subset definition in gnn definition. target set to str
---
 examples/04_training/02_train_tito_model.py |  6 +-
 .../models/gnn/dynedge_kaggle_tito.py       | 85 ++++++++++++-------
 2 files changed, 57 insertions(+), 34 deletions(-)

diff --git a/examples/04_training/02_train_tito_model.py b/examples/04_training/02_train_tito_model.py
index 735dea055..eeffacfed 100644
--- a/examples/04_training/02_train_tito_model.py
+++ b/examples/04_training/02_train_tito_model.py
@@ -107,7 +107,11 @@ def main(
     # Building model
     gnn = DynEdgeTITO(
         nb_inputs=graph_definition.nb_outputs,
+        features_subset=[0, 1, 2, 3],
+        dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256), (256, 256)],
         global_pooling_schemes=["max"],
+        use_global_features=True,
+        use_post_processing_layers=True,
     )
     task = DirectionReconstructionWithKappa(
         hidden_size=gnn.nb_outputs,
@@ -212,7 +216,7 @@ def main(
             "Name of feature to use as regression target (default: "
             "%(default)s)"
         ),
-        default=["direction"],
+        default="direction",
     )
 
     parser.add_argument(
diff --git a/src/graphnet/models/gnn/dynedge_kaggle_tito.py b/src/graphnet/models/gnn/dynedge_kaggle_tito.py
index 4a4662256..266739c8b 100644
--- a/src/graphnet/models/gnn/dynedge_kaggle_tito.py
+++ b/src/graphnet/models/gnn/dynedge_kaggle_tito.py
@@ -8,16 +8,16 @@
 Solution by TITO.
 """
 
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Union
 
 import torch
 from torch import Tensor, LongTensor
 
 from torch_geometric.data import Data
-from torch_geometric.utils import to_dense_batch
 from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
 
 from graphnet.models.components.layers import DynTrans
+from graphnet.utilities.config import save_model_config
 from graphnet.models.gnn.gnn import GNN
 from graphnet.models.utils import calculate_xyzt_homophily
 
@@ -30,16 +30,19 @@
 
 
 class DynEdgeTITO(GNN):
-    """DynEdge (dynamical edge convolutional) model."""
+    """DynEdgeTITO (dynamical edge convolutional with Transformer) model."""
 
+    @save_model_config
     def __init__(
         self,
         nb_inputs: int,
         features_subset: List[int] = None,
         dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
         global_pooling_schemes: List[str] = ["max"],
+        use_global_features: bool = True,
+        use_post_processing_layers: bool = True,
     ):
-        """Construct `DynEdge`.
+        """Construct `DynEdgeTITO`.
 
         Args:
             nb_inputs: Number of input features on each node.
@@ -48,10 +51,14 @@ def __init__(
                 neighbours clustering. Defaults to [0,1,2,3].
             dyntrans_layer_sizes: The layer sizes, or latent feature
                 dimenions, used in the `DynTrans` layer.
+                Defaults to [(256, 256), (256, 256), (256, 256), (256, 256)].
             global_pooling_schemes: The list global pooling schemes to use.
                 Options are: "min", "max", "mean", and "sum".
+            use_global_features: Whether to use global features after pooling.
+            use_post_processing_layers: Whether to use post-processing layers
+                after the `DynTrans` layers.
         """
-        # DynEdge layer sizes
+        # DynTrans layer sizes
         if dyntrans_layer_sizes is None:
             dyntrans_layer_sizes = [
                 (
@@ -66,6 +73,10 @@ def __init__(
                     256,
                     256,
                 ),
+                (
+                    256,
+                    256,
+                ),
             ]
 
         assert isinstance(dyntrans_layer_sizes, list)
@@ -120,7 +131,10 @@ def __init__(
         self._activation = torch.nn.LeakyReLU()
         self._nb_inputs = nb_inputs
         self._nb_global_variables = 5 + nb_inputs
+        self._nb_neighbours = 8
         self._features_subset = features_subset or [0, 1, 2, 3]
+        self._use_global_features = use_global_features
+        self._use_post_processing_layers = use_post_processing_layers
         self._construct_layers()
 
     def _construct_layers(self) -> None:
@@ -140,16 +154,21 @@ def _construct_layers(self) -> None:
             self._conv_layers.append(conv_layer)
             nb_latent_features = sizes[-1]
 
-        post_processing_layers = []
-        layer_sizes = [nb_latent_features] + list(
-            self._post_processing_layer_sizes
-        )
-        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
-            post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
-            post_processing_layers.append(self._activation)
-        last_posting_layer_output_dim = nb_out
+        if self._use_post_processing_layers:
+            post_processing_layers = []
+            layer_sizes = [nb_latent_features] + list(
+                self._post_processing_layer_sizes
+            )
+            for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
+                post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
+                post_processing_layers.append(self._activation)
+            last_posting_layer_output_dim = nb_out
 
-        self._post_processing = torch.nn.Sequential(*post_processing_layers)
+            self._post_processing = torch.nn.Sequential(
+                *post_processing_layers
+            )
+        else:
+            last_posting_layer_output_dim = nb_latent_features
 
         # Read-out operations
         nb_poolings = (
@@ -158,7 +177,8 @@ def _construct_layers(self) -> None:
             else 1
         )
         nb_latent_features = last_posting_layer_output_dim * nb_poolings
-        nb_latent_features += self._nb_global_variables
+        if self._use_global_features:
+            nb_latent_features += self._nb_global_variables
 
         readout_layers = []
         layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
@@ -217,32 +237,31 @@ def forward(self, data: Data) -> Tensor:
         # Convenience variables
         x, edge_index, batch = data.x, data.edge_index, data.batch
 
-        global_variables = self._calculate_global_variables(
-            x,
-            edge_index,
-            batch,
-            torch.log10(data.n_pulses),
-        )
+        if self._use_global_features:
+            global_variables = self._calculate_global_variables(
+                x,
+                edge_index,
+                batch,
+                torch.log10(data.n_pulses),
+            )
 
         # DynEdge-convolutions
         for conv_layer in self._conv_layers:
             x = conv_layer(x, edge_index, batch)
 
-        x, mask = to_dense_batch(x, batch)
-        x = x[mask]
-
         # Post-processing
-        x = self._post_processing(x)
+        if self._use_post_processing_layers:
+            x = self._post_processing(x)
 
-        # (Optional) Global pooling
         x = self._global_pooling(x, batch=batch)
-        x = torch.cat(
-            [
-                x,
-                global_variables,
-            ],
-            dim=1,
-        )
+        if self._use_global_features:
+            x = torch.cat(
+                [
+                    x,
+                    global_variables,
+                ],
+                dim=1,
+            )
 
         # Read-out
         x = self._readout(x)