Commit 5695a53
More adjustability to TITO
Aske-Rosted committed Jan 22, 2024
1 parent 51a8817 commit 5695a53
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions src/graphnet/models/gnn/dynedge_kaggle_tito.py
@@ -39,6 +39,10 @@ def __init__(
         global_pooling_schemes: List[str] = ["max"],
         use_global_features: bool = True,
         use_post_processing_layers: bool = True,
+        post_processing_layer_sizes: Optional[List[int]] = None,
+        readout_layer_sizes: Optional[List[int]] = None,
+        n_head: int = 8,
+        nb_neighbours: int = 8,
     ):
         """Construct `DynEdgeTITO`.
@@ -53,8 +57,12 @@ def __init__(
             global_pooling_schemes: The list global pooling schemes to use.
                 Options are: "min", "max", "mean", and "sum".
             use_global_features: Whether to use global features after pooling.
-            use_post_processing_layers: Whether to use post-processing layers
-                after the `DynTrans` layers.
+            use_post_processing_layers: Whether to use post-processing layers after the `DynTrans` layers.
+            post_processing_layer_sizes: (Optional) The layer sizes used in the post-processing layers. Defaults to [336, 256].
+            readout_layer_sizes: (Optional) The layer sizes used in the readout layers. Defaults to [256, 128].
+            n_head: The number of heads to use in the `DynTrans` layer.
+            nb_neighbours: The number of neighbours to use in the `DynTrans`
+                layer.
         """
         # DynTrans layer sizes
         if dyntrans_layer_sizes is None:
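With these additions, the previously hard-coded choices of DynEdgeTITO can all be set at construction time. A minimal usage sketch (it assumes `DynEdgeTITO` is exported from `graphnet.models.gnn`; the argument values are illustrative, not recommendations):

from graphnet.models.gnn import DynEdgeTITO

# nb_inputs must match the number of node features in your graphs;
# the remaining keyword arguments are the knobs added in this commit.
model = DynEdgeTITO(
    nb_inputs=7,
    post_processing_layer_sizes=[512, 256],  # overrides the default [336, 256]
    readout_layer_sizes=[256, 128],          # same as the default
    n_head=16,                               # attention heads per DynTrans block
    nb_neighbours=16,                        # neighbours used by DynTrans
)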
@@ -88,18 +96,20 @@ def __init__(
         self._dyntrans_layer_sizes = dyntrans_layer_sizes

         # Post-processing layer sizes
-        post_processing_layer_sizes = [
-            336,
-            256,
-        ]
+        if post_processing_layer_sizes is None:
+            post_processing_layer_sizes = [
+                336,
+                256,
+            ]

         self._post_processing_layer_sizes = post_processing_layer_sizes

         # Read-out layer sizes
-        readout_layer_sizes = [
-            256,
-            128,
-        ]
+        if readout_layer_sizes is None:
+            readout_layer_sizes = [
+                256,
+                128,
+            ]

         self._readout_layer_sizes = readout_layer_sizes

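The `if ... is None` fallback, rather than a list literal in the signature, is the standard Python idiom for mutable defaults: a default list in the signature is created once and shared by every call that omits the argument. A generic sketch of the hazard (not graphnet code):

def bad(sizes=[336, 256]):
    sizes.append(0)  # mutates the one shared default list
    return sizes

def good(sizes=None):
    if sizes is None:  # a fresh list on every call
        sizes = [336, 256]
    sizes.append(0)
    return sizes

bad()
print(bad())   # [336, 256, 0, 0] -- the default kept the first mutation
print(good())  # [336, 256, 0] every time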
Expand Down Expand Up @@ -129,10 +139,11 @@ def __init__(
self._activation = torch.nn.LeakyReLU()
self._nb_inputs = nb_inputs
self._nb_global_variables = 5 + nb_inputs
self._nb_neighbours = 8
self._nb_neighbours = nb_neighbours
self._features_subset = features_subset or [0, 1, 2, 3]
self._use_global_features = use_global_features
self._use_post_processing_layers = use_post_processing_layers
self._n_head = n_head
self._construct_layers()

def _construct_layers(self) -> None:
@@ -147,7 +158,7 @@ def _construct_layers(self) -> None:
                 [nb_latent_features] + list(sizes),
                 aggr="max",
                 features_subset=self._features_subset,
-                n_head=8,
+                n_head=self._n_head,
             )
             self._conv_layers.append(conv_layer)
             nb_latent_features = sizes[-1]
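Note that `n_head` now reaches every `DynTrans` block. If the attention inside `DynTrans` is `torch.nn.MultiheadAttention` (an assumption about graphnet internals, not confirmed by this diff), its embedding dimension must be divisible by the head count, so a pre-flight check can catch bad combinations of `dyntrans_layer_sizes` and `n_head` early:

# Hypothetical sanity check; the layer sizes here are illustrative.
dyntrans_layer_sizes = [[256, 256], [256, 256], [256, 256], [256, 256]]
n_head = 8
for sizes in dyntrans_layer_sizes:
    assert sizes[-1] % n_head == 0, (
        f"n_head={n_head} must divide the layer size {sizes[-1]}"
    )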
