Skip to content

Commit

Permalink
TITO more customizability
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Jan 22, 2024
1 parent 8718ff8 commit d511f1c
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions src/graphnet/models/gnn/dynedge_kaggle_tito.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(
global_pooling_schemes: List[str] = ["max"],
use_global_features: bool = True,
use_post_processing_layers: bool = True,
post_processing_layer_sizes: Optional[List[int]] = None,
readout_layer_sizes: Optional[List[int]] = None,
n_head: int = 16,
nb_neighbours: int = 8,
):
"""Construct `DynEdgeTITO`.
Expand All @@ -57,6 +61,11 @@ def __init__(
use_global_features: Whether to use global features after pooling.
use_post_processing_layers: Whether to use post-processing layers
after the `DynTrans` layers.
post_processing_layer_sizes: (Optional) The layer sizes used in the post-processing layers. Defaults to [336, 256].
readout_layer_sizes: (Optional) The layer sizes used in the readout layers. Defaults to [256, 128].
n_head: The number of heads to use in the `DynTrans` layer.
nb_neighbours: The number of neighbours to use in the `DynTrans`
layer.
"""
# DynTrans layer sizes
if dyntrans_layer_sizes is None:
Expand Down Expand Up @@ -90,18 +99,20 @@ def __init__(
self._dyntrans_layer_sizes = dyntrans_layer_sizes

# Post-processing layer sizes
post_processing_layer_sizes = [
336,
256,
]
if post_processing_layer_sizes is None:
post_processing_layer_sizes = [
336,
256,
]

self._post_processing_layer_sizes = post_processing_layer_sizes

# Read-out layer sizes
readout_layer_sizes = [
256,
128,
]
if readout_layer_sizes is None:
readout_layer_sizes = [
256,
128,
]

self._readout_layer_sizes = readout_layer_sizes

Expand Down Expand Up @@ -131,10 +142,11 @@ def __init__(
self._activation = torch.nn.LeakyReLU()
self._nb_inputs = nb_inputs
self._nb_global_variables = 5 + nb_inputs
self._nb_neighbours = 8
self._nb_neighbours = nb_neighbours
self._features_subset = features_subset or [0, 1, 2, 3]
self._use_global_features = use_global_features
self._use_post_processing_layers = use_post_processing_layers
self._n_head = n_head
self._construct_layers()

def _construct_layers(self) -> None:
Expand All @@ -149,7 +161,7 @@ def _construct_layers(self) -> None:
[nb_latent_features] + list(sizes),
aggr="max",
features_subset=self._features_subset,
n_head=8,
n_head=self._n_head,
)
self._conv_layers.append(conv_layer)
nb_latent_features = sizes[-1]
Expand Down

0 comments on commit d511f1c

Please sign in to comment.