forked from graphnet-team/graphnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5695a53
commit dd504bd
Showing
4 changed files
with
218 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""RNN_DynEdge model implementation.""" | ||
from typing import List, Optional, Tuple, Union | ||
|
||
import torch | ||
from graphnet.models.gnn.gnn import GNN | ||
from graphnet.models.gnn.dynedge import DynEdge | ||
from graphnet.models.gnn.dynedge_kaggle_tito import DynEdgeTITO | ||
from graphnet.models.rnn.node_rnn import Node_RNN | ||
|
||
# from graphnet.models.rnn.dom_window_rnn import Dom_Window_RNN | ||
from graphnet.models.rnn.node_transformer import Node_Transformer | ||
|
||
from graphnet.utilities.config import save_model_config | ||
from torch_geometric.data import Data | ||
|
||
|
||
class RNN_TITO(GNN):
    """The RNN_TITO model class.

    Chains a `Node_RNN` with a `DynEdgeTITO` model; intended for data with a
    large amount of DOM activations per event. This model works only with the
    non-standard, time-series dataset expected by `Node_RNN`; see `Node_RNN`
    for more details.
    """

    @save_model_config
    def __init__(
        self,
        nb_inputs: int,
        *,
        nb_neighbours: int = 8,
        RNN_layers: int = 2,
        RNN_hidden_size: int = 64,
        RNN_dropout: float = 0.5,
        features_subset: Optional[List[int]] = None,
        dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
        post_processing_layer_sizes: Optional[List[int]] = None,
        readout_layer_sizes: Optional[List[int]] = None,
        global_pooling_schemes: Optional[List[str]] = None,
        embedding_dim: Optional[int] = None,
        n_head: int = 16,
        use_global_features: bool = True,
        use_post_processing_layers: bool = True,
    ):
        """Initialize the RNN_TITO model.

        Args:
            nb_inputs: Number of input features. Passed to the `GNN` base
                class only; the inner `Node_RNN` is always built with 2
                inputs.
            nb_neighbours: Number of neighbours to consider. Defaults to 8.
            RNN_layers: Number of RNN layers. Defaults to 2.
            RNN_hidden_size: Size of the hidden state of the RNN. Also
                determines the size of the output of the RNN.
                Defaults to 64.
            RNN_dropout: Dropout to use in the RNN. Defaults to 0.5.
            features_subset: The subset of latent features on each node that
                are used as metric dimensions when performing the k-nearest
                neighbours clustering. Defaults to None, which is forwarded
                unchanged to `DynEdgeTITO`.
            dyntrans_layer_sizes: Sizes of the hidden layers of the DynTrans
                model, one tuple per layer. Defaults to four layers of
                `(256, 256)`.
            post_processing_layer_sizes: Sizes of the hidden layers of the
                post-processing model. Defaults to None, which is forwarded
                unchanged to `DynEdgeTITO`.
            readout_layer_sizes: Sizes of the hidden layers of the readout
                model. Defaults to `[256, 128]`.
            global_pooling_schemes: Pooling schemes to use. Defaults to
                `["max"]`.
            embedding_dim: Embedding dimension of the RNN. Defaults to None,
                i.e. no embedding (equivalent to 0).
            n_head: Number of heads to use in the DynTrans model.
                Defaults to 16.
            use_global_features: Whether to use global features after
                pooling. Defaults to True.
            use_post_processing_layers: Whether to use post-processing layers
                after the DynTrans layers. Defaults to True.
        """
        self._nb_neighbours = nb_neighbours
        self._nb_inputs = nb_inputs
        self._RNN_layers = RNN_layers
        self._RNN_hidden_size = RNN_hidden_size
        self._RNN_dropout = RNN_dropout
        # `Node_RNN` switches its embedding off by comparing against 0, so a
        # raw `None` would be treated as "embedding enabled" and crash on
        # `None * 2 * nb_inputs`. Normalise it here.
        self._embedding_dim = 0 if embedding_dim is None else embedding_dim
        self._n_head = n_head
        self._use_global_features = use_global_features
        self._use_post_processing_layers = use_post_processing_layers

        self._features_subset = features_subset
        if dyntrans_layer_sizes is None:
            dyntrans_layer_sizes = [
                (256, 256),
                (256, 256),
                (256, 256),
                (256, 256),
            ]
        else:
            # Coerce inner sequences to tuples (e.g. lists are accepted too).
            dyntrans_layer_sizes = [
                tuple(layer_sizes) for layer_sizes in dyntrans_layer_sizes
            ]

        self._dyntrans_layer_sizes = dyntrans_layer_sizes
        self._post_processing_layer_sizes = post_processing_layer_sizes

        # Build a fresh list per instance rather than sharing one mutable
        # default object across every model.
        if global_pooling_schemes is None:
            global_pooling_schemes = ["max"]
        self._global_pooling_schemes = global_pooling_schemes

        if readout_layer_sizes is None:
            readout_layer_sizes = [
                256,
                128,
            ]
        self._readout_layer_sizes = readout_layer_sizes

        # The model's output width is the width of the last readout layer.
        super().__init__(nb_inputs, self._readout_layer_sizes[-1])

        # The RNN consumes 2 features per time step; presumably time and
        # charge -- TODO confirm against the dataset definition.
        self._rnn = Node_RNN(
            num_layers=self._RNN_layers,
            nb_inputs=2,
            hidden_size=self._RNN_hidden_size,
            RNN_dropout=self._RNN_dropout,
            embedding_dim=self._embedding_dim,
        )

        # `Node_RNN` declares `hidden_size + 5` output features, which is
        # therefore the input width of the downstream DynEdgeTITO model.
        self._dynedge_tito = DynEdgeTITO(
            nb_inputs=self._RNN_hidden_size + 5,
            dyntrans_layer_sizes=self._dyntrans_layer_sizes,
            features_subset=self._features_subset,
            global_pooling_schemes=self._global_pooling_schemes,
            use_global_features=self._use_global_features,
            use_post_processing_layers=self._use_post_processing_layers,
            post_processing_layer_sizes=self._post_processing_layer_sizes,
            readout_layer_sizes=self._readout_layer_sizes,
            n_head=self._n_head,
            nb_neighbours=self._nb_neighbours,
        )

    def forward(self, data: Data) -> torch.Tensor:
        """Apply learnable forward pass of the RNN and TITO model.

        Args:
            data: Input graph; first passed through the `Node_RNN`, whose
                result is then passed to `DynEdgeTITO`.

        Returns:
            The output of the `DynEdgeTITO` model.
        """
        data = self._rnn(data)
        readout = self._dynedge_tito(data)

        return readout
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Recurrent neural network specific modules.""" | ||
|
||
from .node_rnn import Node_RNN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Implementation of the NodeTimeRNN model. | ||
(cannot be used as a standalone model) | ||
""" | ||
import torch | ||
|
||
from graphnet.models.gnn.gnn import GNN | ||
from graphnet.utilities.config import save_model_config | ||
from torch_geometric.data import Data | ||
from typing import Optional | ||
|
||
from graphnet.models.components.embedding import SinusoidalPosEmb | ||
|
||
|
||
class Node_RNN(GNN):
    """Implementation of the RNN model architecture.

    The model takes as input the typical DOM data format and transforms it
    into a time series of DOM activations per DOM, before applying an RNN
    layer and outputting an RNN output for each DOM. This model is, in its
    current state, not intended to be used as a standalone model.
    Furthermore, it needs to be used with a time-series dataset and a
    "cutter" (see NodeAsDOMTimeSeries), which is not standard in the graphnet
    framework.
    """

    @save_model_config
    def __init__(
        self,
        nb_inputs: int,
        hidden_size: int,
        num_layers: int,
        RNN_dropout: float = 0.5,
        embedding_dim: Optional[int] = 0,
    ) -> None:
        """Construct `Node_RNN`.

        Args:
            nb_inputs: Number of features in the input data.
            hidden_size: Number of features for the RNN output and hidden
                layers.
            num_layers: Number of layers in the RNN.
            RNN_dropout: Dropout fraction to use in the RNN. Defaults to 0.5.
            embedding_dim: Embedding dimension of the RNN. `0` or `None`
                means no embedding. Defaults to 0.
        """
        # Treat `None` as "no embedding". Without this, `None != 0` is True
        # and the input-size computation below fails with a TypeError.
        if embedding_dim is None:
            embedding_dim = 0

        self._num_layers = num_layers
        self._hidden_size = hidden_size
        self._embedding_dim = embedding_dim
        self._nb_inputs = nb_inputs

        # The declared output width is the RNN state plus 5 extra columns;
        # presumably the per-DOM summary features already present in
        # `data.x` -- TODO confirm against the dataset definition.
        super().__init__(nb_inputs, hidden_size + 5)

        if self._embedding_dim != 0:
            # With embedding enabled the GRU input width matches the reshape
            # performed in `forward`.
            self._nb_inputs = self._embedding_dim * 2 * nb_inputs

        self._rnn = torch.nn.GRU(
            num_layers=self._num_layers,
            input_size=self._nb_inputs,
            hidden_size=self._hidden_size,
            batch_first=True,
            dropout=RNN_dropout,
        )
        self._emb = SinusoidalPosEmb(dim=self._embedding_dim)

    def forward(self, data: Data) -> torch.Tensor:
        """Apply learnable forward pass to the GNN.

        Args:
            data: Graph object providing `cutter`, `time_series` and `x`.

        Returns:
            The same `data` object, with `data.x` replaced in place by the
            horizontal stack of the original `data.x` and the RNN output.
            Note that a `Data` object is returned despite the annotation.
        """
        # Cumulative sum of the per-sequence lengths gives the split indices;
        # the last one is dropped because it equals the total length.
        cutter = data.cutter.cumsum(0)[:-1]
        # Optional embedding of the time and charge time series data.
        if self._embedding_dim != 0:
            time_series = self._emb(data.time_series * 4096).reshape(
                (
                    data.time_series.shape[0],
                    self._embedding_dim * 2 * data.time_series.shape[-1],
                )
            )
        else:
            time_series = data.time_series

        # Split into one variable-length sequence per DOM and pack them.
        # The split indices are moved to the CPU for `tensor_split`.
        time_series = torch.nn.utils.rnn.pack_sequence(
            time_series.tensor_split(cutter.cpu()), enforce_sorted=False
        )
        # apply RNN per DOM irrespective of batch and return the final state.
        # `[-1]` picks `h_n` from the GRU's `(output, h_n)` return value and
        # `[0]` picks index 0 along its leading (layer) axis.
        # NOTE(review): with `num_layers > 1` this is the first layer's final
        # state rather than the top layer's (`h_n[-1]`) -- confirm intended.
        rnn_out = self._rnn(time_series)[-1][0]
        # combine the RNN output with the DOM summary features
        data.x = torch.hstack([data.x, rnn_out])
        return data