From 5695a5338ef26f183a6c8b28aaf27432ffc1474e Mon Sep 17 00:00:00 2001
From: "askerosted@gmail.com"
Date: Mon, 22 Jan 2024 15:47:20 +0900
Subject: [PATCH] More adjustability to TITO

---
 .../models/gnn/dynedge_kaggle_tito.py | 35 ++++++++++++-------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/src/graphnet/models/gnn/dynedge_kaggle_tito.py b/src/graphnet/models/gnn/dynedge_kaggle_tito.py
index 78b5aebe5..c9fc41417 100644
--- a/src/graphnet/models/gnn/dynedge_kaggle_tito.py
+++ b/src/graphnet/models/gnn/dynedge_kaggle_tito.py
@@ -39,6 +39,10 @@ def __init__(
         global_pooling_schemes: List[str] = ["max"],
         use_global_features: bool = True,
         use_post_processing_layers: bool = True,
+        post_processing_layer_sizes: List[int] = None,
+        readout_layer_sizes: List[int] = None,
+        n_head: int = 8,
+        nb_neighbours: int = 8,
     ):
         """Construct `DynEdgeTITO`.

@@ -53,8 +57,12 @@ def __init__(
             global_pooling_schemes: The list global pooling schemes to use.
                 Options are: "min", "max", "mean", and "sum".
             use_global_features: Whether to use global features after pooling.
-            use_post_processing_layers: Whether to use post-processing layers
-                after the `DynTrans` layers.
+            use_post_processing_layers: Whether to use post-processing layers after the `DynTrans` layers.
+            post_processing_layer_sizes: (Optional) The layer sizes used in the post-processing layers. Defaults to [336, 256].
+            readout_layer_sizes: (Optional) The layer sizes used in the readout layers. Defaults to [256, 128].
+            n_head: The number of heads to use in the `DynTrans` layer.
+            nb_neighbours: The number of neighbours to use in the `DynTrans`
+                layer.
         """
         # DynTrans layer sizes
         if dyntrans_layer_sizes is None:
@@ -88,18 +96,20 @@ def __init__(
         self._dyntrans_layer_sizes = dyntrans_layer_sizes

         # Post-processing layer sizes
-        post_processing_layer_sizes = [
-            336,
-            256,
-        ]
+        if post_processing_layer_sizes is None:
+            post_processing_layer_sizes = [
+                336,
+                256,
+            ]
         self._post_processing_layer_sizes = post_processing_layer_sizes

         # Read-out layer sizes
-        readout_layer_sizes = [
-            256,
-            128,
-        ]
+        if readout_layer_sizes is None:
+            readout_layer_sizes = [
+                256,
+                128,
+            ]
         self._readout_layer_sizes = readout_layer_sizes


@@ -129,10 +139,11 @@ def __init__(
         self._activation = torch.nn.LeakyReLU()
         self._nb_inputs = nb_inputs
         self._nb_global_variables = 5 + nb_inputs
-        self._nb_neighbours = 8
+        self._nb_neighbours = nb_neighbours
         self._features_subset = features_subset or [0, 1, 2, 3]
         self._use_global_features = use_global_features
         self._use_post_processing_layers = use_post_processing_layers
+        self._n_head = n_head
         self._construct_layers()

     def _construct_layers(self) -> None:
@@ -147,7 +158,7 @@ def _construct_layers(self) -> None:
                 [nb_latent_features] + list(sizes),
                 aggr="max",
                 features_subset=self._features_subset,
-                n_head=8,
+                n_head=self._n_head,
             )
             self._conv_layers.append(conv_layer)
             nb_latent_features = sizes[-1]
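
For reference, a minimal usage sketch of the constructor arguments this patch exposes. It is not part of the patch itself: the module path, class name, and argument names are taken from the diff above, while the nb_inputs value and the concrete layer sizes, head count, and neighbour count chosen here are illustrative assumptions, not the patch defaults (those remain [336, 256], [256, 128], n_head=8, nb_neighbours=8).

    # Hypothetical example: instantiating DynEdgeTITO with the newly
    # configurable arguments. Values are for illustration only.
    from graphnet.models.gnn.dynedge_kaggle_tito import DynEdgeTITO

    model = DynEdgeTITO(
        nb_inputs=7,                             # number of input node features (example value)
        global_pooling_schemes=["max", "mean"],
        use_global_features=True,
        use_post_processing_layers=True,
        post_processing_layer_sizes=[512, 256],  # previously hard-coded to [336, 256]
        readout_layer_sizes=[256, 128],          # previously hard-coded to [256, 128]
        n_head=16,                               # previously hard-coded to 8 in the DynTrans layers
        nb_neighbours=12,                        # previously hard-coded to 8
    )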