diff --git a/src/graphnet/models/gnn/RNN_tito.py b/src/graphnet/models/gnn/RNN_tito.py
new file mode 100644
index 000000000..8facecd45
--- /dev/null
+++ b/src/graphnet/models/gnn/RNN_tito.py
@@ -0,0 +1,140 @@
+"""RNN_TITO model implementation."""
+from typing import List, Optional, Tuple, Union
+
+import torch
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.gnn.dynedge import DynEdge
+from graphnet.models.gnn.dynedge_kaggle_tito import DynEdgeTITO
+from graphnet.models.rnn.node_rnn import Node_RNN
+
+# from graphnet.models.rnn.dom_window_rnn import Dom_Window_RNN
+# NOTE(review): this patch does not add `node_transformer` and the name is
+# unused below; confirm the module exists, otherwise this import fails.
+from graphnet.models.rnn.node_transformer import Node_Transformer
+
+from graphnet.utilities.config import save_model_config
+from torch_geometric.data import Data
+
+
+class RNN_TITO(GNN):
+    """The RNN_TITO model class.
+
+    Combines the Node_RNN and DynEdgeTITO models, intended for data with a
+    large amount of DOM activations per event. This model only works with the
+    non-standard dataset format that Node_RNN expects; see Node_RNN for more
+    details.
+    """
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_inputs: int,
+        *,
+        nb_neighbours: int = 8,
+        RNN_layers: int = 2,
+        RNN_hidden_size: int = 64,
+        RNN_dropout: float = 0.5,
+        features_subset: Optional[List[int]] = None,
+        dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
+        post_processing_layer_sizes: Optional[List[int]] = None,
+        readout_layer_sizes: Optional[List[int]] = None,
+        global_pooling_schemes: Optional[List[str]] = None,
+        embedding_dim: Optional[int] = None,
+        n_head: int = 16,
+        use_global_features: bool = True,
+        use_post_processing_layers: bool = True,
+    ):
+        """Initialize the RNN_TITO model.
+
+        Args:
+            nb_inputs: Number of input features. Only forwarded to the GNN
+                base class; the inner Node_RNN is always built with 2 inputs.
+            nb_neighbours: Number of neighbours to consider. Defaults to 8.
+            RNN_layers: Number of RNN layers. Defaults to 2.
+            RNN_hidden_size: Size of the hidden state of the RNN. Also
+                determines the size of the output of the RNN. Defaults to 64.
+            RNN_dropout: Dropout to use in the RNN. Defaults to 0.5.
+            features_subset: The subset of latent features on each node that
+                are used as metric dimensions when performing the k-nearest
+                neighbours clustering. Passed unchanged to DynEdgeTITO.
+            dyntrans_layer_sizes: Sizes of the hidden layers of the DynTrans
+                model, one tuple per layer. Defaults to four (256, 256)
+                layers.
+            post_processing_layer_sizes: Sizes of the hidden layers of the
+                post-processing model. Passed unchanged to DynEdgeTITO.
+            readout_layer_sizes: Sizes of the hidden layers of the readout
+                model. Defaults to [256, 128].
+            global_pooling_schemes: Pooling schemes to use. Defaults to
+                ["max"].
+            embedding_dim: Embedding dimension of the RNN. None or 0 means
+                no embedding. Defaults to None.
+            n_head: Number of heads to use in the DynTrans model. Defaults
+                to 16.
+            use_global_features: Whether to use global features after
+                pooling. Defaults to True.
+            use_post_processing_layers: Whether to use post-processing
+                layers after the DynTrans layers. Defaults to True.
+        """
+        self._nb_neighbours = nb_neighbours
+        self._nb_inputs = nb_inputs
+        self._RNN_layers = RNN_layers
+        self._RNN_hidden_size = RNN_hidden_size
+        self._RNN_dropout = RNN_dropout
+        # Node_RNN tests `embedding_dim != 0`, so None ("no embedding") is
+        # translated to 0 here instead of being passed through.
+        self._embedding_dim = 0 if embedding_dim is None else embedding_dim
+        self._n_head = n_head
+        self._use_global_features = use_global_features
+        self._use_post_processing_layers = use_post_processing_layers
+
+        self._features_subset = features_subset
+        if dyntrans_layer_sizes is None:
+            dyntrans_layer_sizes = [(256, 256)] * 4
+        else:
+            # Accept any sequence per layer but store tuples.
+            dyntrans_layer_sizes = [
+                tuple(layer_sizes) for layer_sizes in dyntrans_layer_sizes
+            ]
+
+        self._dyntrans_layer_sizes = dyntrans_layer_sizes
+        self._post_processing_layer_sizes = post_processing_layer_sizes
+        # Build the default list per call instead of sharing one mutable
+        # default object between all instances.
+        if global_pooling_schemes is None:
+            global_pooling_schemes = ["max"]
+        self._global_pooling_schemes = global_pooling_schemes
+        if readout_layer_sizes is None:
+            readout_layer_sizes = [256, 128]
+        self._readout_layer_sizes = readout_layer_sizes
+
+        super().__init__(nb_inputs, self._readout_layer_sizes[-1])
+
+        # Two time-series features per pulse (presumably time and charge;
+        # confirm against the dataset definition).
+        self._rnn = Node_RNN(
+            num_layers=self._RNN_layers,
+            nb_inputs=2,
+            hidden_size=self._RNN_hidden_size,
+            RNN_dropout=self._RNN_dropout,
+            embedding_dim=self._embedding_dim,
+        )
+
+        # Node_RNN declares `hidden_size + 5` output features, which become
+        # the node features consumed by DynEdgeTITO.
+        self._dynedge_tito = DynEdgeTITO(
+            nb_inputs=self._RNN_hidden_size + 5,
+            dyntrans_layer_sizes=self._dyntrans_layer_sizes,
+            features_subset=self._features_subset,
+            global_pooling_schemes=self._global_pooling_schemes,
+            use_global_features=self._use_global_features,
+            use_post_processing_layers=self._use_post_processing_layers,
+            post_processing_layer_sizes=self._post_processing_layer_sizes,
+            readout_layer_sizes=self._readout_layer_sizes,
+            n_head=self._n_head,
+            nb_neighbours=self._nb_neighbours,
+        )
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Apply learnable forward pass of the RNN and tito model."""
+        data = self._rnn(data)
+        return self._dynedge_tito(data)
diff --git a/src/graphnet/models/gnn/__init__.py
b/src/graphnet/models/gnn/__init__.py
index 2abe3d358..2d3ff7910 100644
--- a/src/graphnet/models/gnn/__init__.py
+++ b/src/graphnet/models/gnn/__init__.py
@@ -4,3 +4,4 @@
 from .dynedge import DynEdge
 from .dynedge_jinst import DynEdgeJINST
 from .dynedge_kaggle_tito import DynEdgeTITO
+from .RNN_tito import RNN_TITO
diff --git a/src/graphnet/models/rnn/__init__.py b/src/graphnet/models/rnn/__init__.py
new file mode 100644
index 000000000..21d29d7e7
--- /dev/null
+++ b/src/graphnet/models/rnn/__init__.py
@@ -0,0 +1,3 @@
+"""Recurrent neural network specific modules."""
+
+from .node_rnn import Node_RNN
diff --git a/src/graphnet/models/rnn/node_rnn.py b/src/graphnet/models/rnn/node_rnn.py
new file mode 100644
index 000000000..a9855bce1
--- /dev/null
+++ b/src/graphnet/models/rnn/node_rnn.py
@@ -0,0 +1,107 @@
+"""Implementation of the NodeTimeRNN model.
+
+(cannot be used as a standalone model)
+"""
+import torch
+
+from graphnet.models.gnn.gnn import GNN
+from graphnet.utilities.config import save_model_config
+from torch_geometric.data import Data
+from typing import Optional
+
+from graphnet.models.components.embedding import SinusoidalPosEmb
+
+
+class Node_RNN(GNN):
+    """Implementation of the RNN model architecture.
+
+    The model takes as input the typical DOM data format and transforms it into
+    a time series of DOM activations per DOM before applying an RNN layer and
+    outputting an RNN output for each DOM. This model is in its current state
+    not intended to be used as a standalone model. Furthermore, it needs to be
+    used with a time-series dataset and a "cutter" (see NodeAsDOMTimeSeries),
+    which is not standard in the graphnet framework.
+    """
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_inputs: int,
+        hidden_size: int,
+        num_layers: int,
+        RNN_dropout: float = 0.5,
+        embedding_dim: Optional[int] = 0,
+    ) -> None:
+        """Construct `NodeTimeRNN`.
+
+        Args:
+            nb_inputs: Number of features in the input data.
+            hidden_size: Number of features for the RNN output and hidden
+                layers.
+            num_layers: Number of layers in the RNN.
+            RNN_dropout: Dropout fraction to use in the RNN. Defaults to 0.5.
+            embedding_dim: Embedding dimension of the RNN. 0 or None means no
+                embedding. Defaults to 0.
+        """
+        # Callers such as RNN_TITO spell "no embedding" as None; fold that
+        # into 0 so the `!= 0` checks below and in `forward` stay correct.
+        if embedding_dim is None:
+            embedding_dim = 0
+
+        self._num_layers = num_layers
+        self._hidden_size = hidden_size
+        self._embedding_dim = embedding_dim
+        self._nb_inputs = nb_inputs
+
+        # Output width is the RNN hidden state plus 5, i.e. `data.x` is
+        # assumed to carry 5 features when stacked in `forward` (confirm).
+        super().__init__(nb_inputs, hidden_size + 5)
+
+        if self._embedding_dim != 0:
+            # Intended to match the flattened width produced in `forward`.
+            self._nb_inputs = self._embedding_dim * 2 * nb_inputs
+
+        self._rnn = torch.nn.GRU(
+            num_layers=self._num_layers,
+            input_size=self._nb_inputs,
+            hidden_size=self._hidden_size,
+            batch_first=True,
+            dropout=RNN_dropout,
+        )
+        self._emb = SinusoidalPosEmb(dim=self._embedding_dim)
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Apply the RNN to each DOM time series and extend `data.x`.
+
+        Args:
+            data: Graph data carrying `x`, `time_series` and `cutter`.
+
+        Returns:
+            The same `data` object, with the RNN state of each sequence
+            stacked onto `data.x` in place.
+        """
+        # `cutter` holds the sequence lengths; the cumulative sum gives the
+        # split points (the last one is the total length and is not needed).
+        cutter = data.cutter.cumsum(0)[:-1]
+        # Optional embedding of the time and charge time series data.
+        if self._embedding_dim != 0:
+            # NOTE(review): the 4096 scale factor is not explained here.
+            time_series = self._emb(data.time_series * 4096).reshape(
+                (
+                    data.time_series.shape[0],
+                    self._embedding_dim * 2 * data.time_series.shape[-1],
+                )
+            )
+        else:
+            time_series = data.time_series
+
+        time_series = torch.nn.utils.rnn.pack_sequence(
+            time_series.tensor_split(cutter.cpu()), enforce_sorted=False
+        )
+        # Apply the RNN per DOM irrespective of batch. The GRU returns
+        # (output, h_n) with h_n stacked by layer, so h_n[-1] is the final
+        # state of the last layer; h_n[0] would discard all later layers.
+        rnn_out = self._rnn(time_series)[-1][-1]
+        # combine the RNN output with the DOM summary features
+        data.x = torch.hstack([data.x, rnn_out])
+        return data