Redefinition of Tito model #611

Merged
12 commits, merged Oct 24, 2023
examples/04_training/02_train_tito_model.py (6 changes: 5 additions & 1 deletion)
@@ -107,7 +107,11 @@ def main(
# Building model
gnn = DynEdgeTITO(
nb_inputs=graph_definition.nb_outputs,
features_subset=[0, 1, 2, 3],
dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256), (256, 256)],
global_pooling_schemes=["max"],
use_global_features=True,
use_post_processing_layers=True,
)
task = DirectionReconstructionWithKappa(
hidden_size=gnn.nb_outputs,
@@ -212,7 +216,7 @@ def main(
"Name of feature to use as regression target (default: "
"%(default)s)"
),
default=["direction"],
default="direction",
)

parser.add_argument(
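For context, the updated example script now builds the GNN roughly as sketched below. This is only an illustration assembled from the diff above: `nb_inputs=7` is a placeholder, whereas the actual script passes `graph_definition.nb_outputs`.

from graphnet.models.gnn.dynedge_kaggle_tito import DynEdgeTITO

# Sketch only: nb_inputs=7 stands in for graph_definition.nb_outputs
# used in examples/04_training/02_train_tito_model.py.
gnn = DynEdgeTITO(
    nb_inputs=7,
    features_subset=[0, 1, 2, 3],
    dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256), (256, 256)],
    global_pooling_schemes=["max"],
    use_global_features=True,
    use_post_processing_layers=True,
)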
src/graphnet/models/gnn/dynedge_kaggle_tito.py (85 changes: 52 additions & 33 deletions)
@@ -8,16 +8,16 @@
Solution by TITO.
"""

from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Union

import torch
from torch import Tensor, LongTensor

from torch_geometric.data import Data
from torch_geometric.utils import to_dense_batch
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

from graphnet.models.components.layers import DynTrans
from graphnet.utilities.config import save_model_config
from graphnet.models.gnn.gnn import GNN
from graphnet.models.utils import calculate_xyzt_homophily

@@ -30,16 +30,19 @@


class DynEdgeTITO(GNN):
"""DynEdge (dynamical edge convolutional) model."""
"""DynEdgeTITO (dynamical edge convolutional with Transformer) model."""

@save_model_config
def __init__(
self,
nb_inputs: int,
features_subset: List[int] = None,
dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
global_pooling_schemes: List[str] = ["max"],
use_global_features: bool = True,
use_post_processing_layers: bool = True,
):
"""Construct `DynEdge`.
"""Construct `DynEdgeTITO`.

Args:
nb_inputs: Number of input features on each node.
@@ -48,10 +51,14 @@ def __init__(
neighbours clustering. Defaults to [0,1,2,3].
dyntrans_layer_sizes: The layer sizes, or latent feature dimensions,
used in the `DynTrans` layer.
Defaults to [(256, 256), (256, 256), (256, 256), (256, 256)].
global_pooling_schemes: The list of global pooling schemes to use.
Options are: "min", "max", "mean", and "sum".
use_global_features: Whether to use global features after pooling.
use_post_processing_layers: Whether to use post-processing layers
after the `DynTrans` layers.
"""
# DynEdge layer sizes
# DynTrans layer sizes
if dyntrans_layer_sizes is None:
dyntrans_layer_sizes = [
(
@@ -66,6 +73,10 @@
256,
256,
),
(
256,
256,
),
]

assert isinstance(dyntrans_layer_sizes, list)
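The added block above extends the built-in default for `dyntrans_layer_sizes` by one more `(256, 256)` tuple, so the fallback matches the four-block configuration used in the example script. Written compactly (illustrative shorthand only, not code from the module):

dyntrans_layer_sizes = [(256, 256)] * 4  # four DynTrans blocks with 256-dim layers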
@@ -120,7 +131,10 @@ def __init__(
self._activation = torch.nn.LeakyReLU()
self._nb_inputs = nb_inputs
self._nb_global_variables = 5 + nb_inputs
self._nb_neighbours = 8
self._features_subset = features_subset or [0, 1, 2, 3]
self._use_global_features = use_global_features
self._use_post_processing_layers = use_post_processing_layers
self._construct_layers()

def _construct_layers(self) -> None:
@@ -140,16 +154,21 @@ def _construct_layers(self) -> None:
self._conv_layers.append(conv_layer)
nb_latent_features = sizes[-1]

post_processing_layers = []
layer_sizes = [nb_latent_features] + list(
self._post_processing_layer_sizes
)
for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
post_processing_layers.append(self._activation)
last_posting_layer_output_dim = nb_out
if self._use_post_processing_layers:
post_processing_layers = []
layer_sizes = [nb_latent_features] + list(
self._post_processing_layer_sizes
)
for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
post_processing_layers.append(self._activation)
last_posting_layer_output_dim = nb_out

self._post_processing = torch.nn.Sequential(*post_processing_layers)
self._post_processing = torch.nn.Sequential(
*post_processing_layers
)
else:
last_posting_layer_output_dim = nb_latent_features

# Read-out operations
nb_poolings = (
@@ -158,7 +177,8 @@ def _construct_layers(self) -> None:
else 1
)
nb_latent_features = last_posting_layer_output_dim * nb_poolings
nb_latent_features += self._nb_global_variables
if self._use_global_features:
nb_latent_features += self._nb_global_variables

readout_layers = []
layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
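The width of the read-out input now depends on both new flags. A small, illustrative sketch of the bookkeeping implied by the diff (the function and argument names here are hypothetical and not part of the module):

def readout_input_width(
    last_layer_dim: int,      # last post-processing size, or the last DynTrans
                              # size if post-processing layers are disabled
    nb_inputs: int,           # number of node input features
    nb_pooling_schemes: int,  # len(global_pooling_schemes), or 1 if None
    use_global_features: bool,
) -> int:
    # One pooled copy of the node features per pooling scheme.
    width = last_layer_dim * nb_pooling_schemes
    # Global variables contribute 5 + nb_inputs entries
    # (self._nb_global_variables in the diff).
    if use_global_features:
        width += 5 + nb_inputs
    return width

# e.g. with a single pooling scheme and global features enabled:
# width = last_layer_dim + 5 + nb_inputs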
@@ -217,32 +237,31 @@ def forward(self, data: Data) -> Tensor:
# Convenience variables
x, edge_index, batch = data.x, data.edge_index, data.batch

global_variables = self._calculate_global_variables(
x,
edge_index,
batch,
torch.log10(data.n_pulses),
)
if self._use_global_features:
global_variables = self._calculate_global_variables(
x,
edge_index,
batch,
torch.log10(data.n_pulses),
)

# DynEdge-convolutions
for conv_layer in self._conv_layers:
x = conv_layer(x, edge_index, batch)

x, mask = to_dense_batch(x, batch)
x = x[mask]

# Post-processing
x = self._post_processing(x)
if self._use_post_processing_layers:
x = self._post_processing(x)

# (Optional) Global pooling
x = self._global_pooling(x, batch=batch)
x = torch.cat(
[
x,
global_variables,
],
dim=1,
)
if self._use_global_features:
x = torch.cat(
[
x,
global_variables,
],
dim=1,
)

# Read-out
x = self._readout(x)