add the RNN and RNN_TITO modules
Aske-Rosted committed Jan 22, 2024
1 parent 5695a53 commit dd504bd
Showing 4 changed files with 218 additions and 0 deletions.
129 changes: 129 additions & 0 deletions src/graphnet/models/gnn/RNN_tito.py
@@ -0,0 +1,129 @@
"""RNN_DynEdge model implementation."""
from typing import List, Optional, Tuple, Union

import torch
from graphnet.models.gnn.gnn import GNN
from graphnet.models.gnn.dynedge import DynEdge
from graphnet.models.gnn.dynedge_kaggle_tito import DynEdgeTITO
from graphnet.models.rnn.node_rnn import Node_RNN

# from graphnet.models.rnn.dom_window_rnn import Dom_Window_RNN
from graphnet.models.rnn.node_transformer import Node_Transformer

from graphnet.utilities.config import save_model_config
from torch_geometric.data import Data


class RNN_TITO(GNN):
"""The RNN_DynEdge model class.
Combines the Node_RNN and DynEdgeTITO models, intended for data with large
amount of DOM activations per event. This model works only with non-
standard dataset specific to the Node_RNN model see Node_RNN for more
details.
"""

    @save_model_config
    def __init__(
        self,
        nb_inputs: int,
        *,
        nb_neighbours: int = 8,
        RNN_layers: int = 2,
        RNN_hidden_size: int = 64,
        RNN_dropout: float = 0.5,
        features_subset: Optional[List[int]] = None,
        dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
        post_processing_layer_sizes: Optional[List[int]] = None,
        readout_layer_sizes: Optional[List[int]] = None,
        global_pooling_schemes: List[str] = ["max"],
        embedding_dim: Optional[int] = None,
        n_head: int = 16,
        use_global_features: bool = True,
        use_post_processing_layers: bool = True,
    ):
"""Initialize the RNN_DynEdge model.
Args:
nb_inputs (int): Number of input features.
nb_neighbours (int, optional): Number of neighbours to consider.
Defaults to 8.
RNN_layers (int, optional): Number of RNN layers.
Defaults to 1.
RNN_hidden_size (int, optional): Size of the hidden state of the RNN. Also determines the size of the output of the RNN.
Defaults to 64.
RNN_dropout (float, optional): Dropout to use in the RNN. Defaults to 0.5.
features_subset (List[int], optional): The subset of latent
features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0,1,2,3]
dyntrans_layer_sizes (List[Tuple[int, ...]], optional): List of tuples representing the sizes of the hidden layers of the DynTrans model.
post_processing_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the post-processing model.
readout_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the readout model.
global_pooling_schemes (Union[str, List[str]], optional): Pooling schemes to use. Defaults to None.
embedding_dim (int, optional): Embedding dimension of the RNN. Defaults to None ie. no embedding.
n_head (int, optional): Number of heads to use in the DynTrans model. Defaults to 16.
use_global_features (bool, optional): Whether to use global features after pooling. Defaults to True.
use_post_processing_layers (bool, optional): Whether to use post-processing layers after the DynTrans layers. Defaults to True.
"""
        self._nb_neighbours = nb_neighbours
        self._nb_inputs = nb_inputs
        self._RNN_layers = RNN_layers
        self._RNN_hidden_size = RNN_hidden_size
        self._RNN_dropout = RNN_dropout
        self._embedding_dim = embedding_dim
        self._n_head = n_head
        self._use_global_features = use_global_features
        self._use_post_processing_layers = use_post_processing_layers

        self._features_subset = features_subset
        if dyntrans_layer_sizes is None:
            dyntrans_layer_sizes = [
                (256, 256),
                (256, 256),
                (256, 256),
                (256, 256),
            ]
        else:
            dyntrans_layer_sizes = [
                tuple(layer_sizes) for layer_sizes in dyntrans_layer_sizes
            ]

        self._dyntrans_layer_sizes = dyntrans_layer_sizes
        self._post_processing_layer_sizes = post_processing_layer_sizes
        self._global_pooling_schemes = global_pooling_schemes
        if readout_layer_sizes is None:
            readout_layer_sizes = [
                256,
                128,
            ]
        self._readout_layer_sizes = readout_layer_sizes

        super().__init__(nb_inputs, self._readout_layer_sizes[-1])

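        # The RNN runs over the per-DOM (time, charge) series, hence nb_inputs=2.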
        self._rnn = Node_RNN(
            num_layers=self._RNN_layers,
            nb_inputs=2,
            hidden_size=self._RNN_hidden_size,
            RNN_dropout=self._RNN_dropout,
            embedding_dim=self._embedding_dim,
        )

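        # The per-DOM RNN state is concatenated with the DOM summary features
        # in Node_RNN.forward, hence the "+ 5" on the input size.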
        self._dynedge_tito = DynEdgeTITO(
            nb_inputs=self._RNN_hidden_size + 5,
            dyntrans_layer_sizes=self._dyntrans_layer_sizes,
            features_subset=self._features_subset,
            global_pooling_schemes=self._global_pooling_schemes,
            use_global_features=self._use_global_features,
            use_post_processing_layers=self._use_post_processing_layers,
            post_processing_layer_sizes=self._post_processing_layer_sizes,
            readout_layer_sizes=self._readout_layer_sizes,
            n_head=self._n_head,
            nb_neighbours=self._nb_neighbours,
        )

    def forward(self, data: Data) -> torch.Tensor:
        """Apply learnable forward pass of the RNN and TITO models."""
        data = self._rnn(data)
        # data = self._node_transformer(data)
        readout = self._dynedge_tito(data)

        return readout
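
A minimal usage sketch (not part of this commit; nb_inputs is a placeholder that depends on the dataset) showing how the newly exported RNN_TITO model can be instantiated with the defaults documented above:

from graphnet.models.gnn import RNN_TITO

model = RNN_TITO(
    nb_inputs=7,  # placeholder feature count, dataset dependent
    RNN_layers=2,
    RNN_hidden_size=64,
    RNN_dropout=0.5,
    global_pooling_schemes=["max"],
)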
1 change: 1 addition & 0 deletions src/graphnet/models/gnn/__init__.py
@@ -4,3 +4,4 @@
from .dynedge import DynEdge
from .dynedge_jinst import DynEdgeJINST
from .dynedge_kaggle_tito import DynEdgeTITO
from .RNN_tito import RNN_TITO
3 changes: 3 additions & 0 deletions src/graphnet/models/rnn/__init__.py
@@ -0,0 +1,3 @@
"""Recurrent neural network specific modules."""

from .node_rnn import Node_RNN
85 changes: 85 additions & 0 deletions src/graphnet/models/rnn/node_rnn.py
@@ -0,0 +1,85 @@
"""Implementation of the NodeTimeRNN model.
(cannot be used as a standalone model)
"""
import torch

from graphnet.models.gnn.gnn import GNN
from graphnet.utilities.config import save_model_config
from torch_geometric.data import Data
from typing import Optional

from graphnet.models.components.embedding import SinusoidalPosEmb


class Node_RNN(GNN):
"""Implementation of the RNN model architecture.
The model takes as input the typical DOM data format and transforms it into
a time series of DOM activations pr. DOM. before applying a RNN layer and
outputting the an RNN output for each DOM. This model is in it's current
state not intended to be used as a standalone model. Furthermore, it needs
to be used with a time-series dataset and a "cutter" (see
NodeAsDOMTimeSeries), which is not standard in the graphnet framework.
"""

    @save_model_config
    def __init__(
        self,
        nb_inputs: int,
        hidden_size: int,
        num_layers: int,
        RNN_dropout: float = 0.5,
        embedding_dim: int = 0,
    ) -> None:
        """Construct `NodeTimeRNN`.

        Args:
            nb_inputs: Number of features in the input data.
            hidden_size: Number of features for the RNN output and hidden
                layers.
            num_layers: Number of layers in the RNN.
            RNN_dropout: Dropout fraction to use in the RNN. Defaults to 0.5.
            embedding_dim: Embedding dimension of the RNN. Defaults to 0,
                i.e. no embedding.
        """
        self._num_layers = num_layers
        self._hidden_size = hidden_size
        self._embedding_dim = embedding_dim
        self._nb_inputs = nb_inputs

        super().__init__(nb_inputs, hidden_size + 5)

        if self._embedding_dim != 0:
            self._nb_inputs = self._embedding_dim * 2 * nb_inputs

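        # With embedding enabled, each input feature is expanded to
        # 2 * embedding_dim sinusoidal components (see forward), so the GRU
        # input size grows accordingly.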
        self._rnn = torch.nn.GRU(
            num_layers=self._num_layers,
            input_size=self._nb_inputs,
            hidden_size=self._hidden_size,
            batch_first=True,
            dropout=RNN_dropout,
        )
        self._emb = SinusoidalPosEmb(dim=self._embedding_dim)

    def forward(self, data: Data) -> Data:
        """Apply learnable forward pass to the GNN."""
        cutter = data.cutter.cumsum(0)[:-1]
        # Optional embedding of the time and charge time series data.
        if self._embedding_dim != 0:
            time_series = self._emb(data.time_series * 4096).reshape(
                (
                    data.time_series.shape[0],
                    self._embedding_dim * 2 * data.time_series.shape[-1],
                )
            )
        else:
            time_series = data.time_series

        time_series = torch.nn.utils.rnn.pack_sequence(
            time_series.tensor_split(cutter.cpu()), enforce_sorted=False
        )
        # Apply the RNN per DOM irrespective of batch and return the final state.
        rnn_out = self._rnn(time_series)[-1][0]
        # Combine the RNN output with the DOM summary features.
        data.x = torch.hstack([data.x, rnn_out])
        return data
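
For reference, a minimal self-contained sketch (not part of this commit) of the pack_sequence/GRU pattern used in the forward pass above; the pulse counts and feature values are placeholders:

import torch

pulses_per_dom = [7, 3, 5]                    # placeholder pulse counts for three DOMs
series = torch.randn(sum(pulses_per_dom), 2)  # flattened (time, charge) series
cutter = torch.tensor(pulses_per_dom).cumsum(0)[:-1]

# Split the flat series back into one variable-length sequence per DOM and pack them.
packed = torch.nn.utils.rnn.pack_sequence(
    series.tensor_split(cutter), enforce_sorted=False
)
gru = torch.nn.GRU(input_size=2, hidden_size=64, num_layers=2, batch_first=True)
h_n = gru(packed)[-1]   # final hidden states, shape (num_layers, num_doms, hidden_size)
per_dom_state = h_n[0]  # shape (3, 64), mirroring `self._rnn(time_series)[-1][0]`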
