
Commit: snake_case
Aske-Rosted committed Feb 2, 2024
1 parent 5d79ce7 commit 9f96660
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/graphnet/models/gnn/RNN_tito.py
@@ -26,9 +26,9 @@ def __init__(
         nb_inputs: int,
         *,
         nb_neighbours: int = 8,
-        RNN_layers: int = 2,
-        RNN_hidden_size: int = 64,
-        RNN_dropout: float = 0.5,
+        rnn_layers: int = 2,
+        rnn_hidden_size: int = 64,
+        rnn_dropout: float = 0.5,
         features_subset: Optional[List[int]] = None,
         dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
         post_processing_layer_sizes: Optional[List[int]] = None,
@@ -45,11 +45,11 @@ def __init__(
             nb_inputs (int): Number of input features.
             nb_neighbours (int, optional): Number of neighbours to consider.
                 Defaults to 8.
-            RNN_layers (int, optional): Number of RNN layers.
+            rnn_layers (int, optional): Number of RNN layers.
                 Defaults to 1.
-            RNN_hidden_size (int, optional): Size of the hidden state of the RNN. Also determines the size of the output of the RNN.
+            rnn_hidden_size (int, optional): Size of the hidden state of the RNN. Also determines the size of the output of the RNN.
                 Defaults to 64.
-            RNN_dropout (float, optional): Dropout to use in the RNN. Defaults to 0.5.
+            rnn_dropout (float, optional): Dropout to use in the RNN. Defaults to 0.5.
             features_subset (List[int], optional): The subset of latent
                 features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0,1,2,3]
             dyntrans_layer_sizes (List[Tuple[int, ...]], optional): List of tuples representing the sizes of the hidden layers of the DynTrans model.
@@ -63,9 +63,9 @@ def __init__(
         """
         self._nb_neighbours = nb_neighbours
         self._nb_inputs = nb_inputs
-        self._RNN_layers = RNN_layers
-        self._RNN_hidden_size = RNN_hidden_size  # RNN_hidden_size
-        self._RNN_dropout = RNN_dropout
+        self._rnn_layers = rnn_layers
+        self._rnn_hidden_size = rnn_hidden_size
+        self._rnn_dropout = rnn_dropout
         self._embedding_dim = embedding_dim
         self._n_head = n_head
         self._use_global_features = use_global_features
@@ -97,15 +97,15 @@ def __init__(
         super().__init__(nb_inputs, self._readout_layer_sizes[-1])
 
         self._rnn = Node_RNN(
-            num_layers=self._RNN_layers,
+            num_layers=self._rnn_layers,
             nb_inputs=2,
-            hidden_size=self._RNN_hidden_size,
-            RNN_dropout=self._RNN_dropout,
+            hidden_size=self._rnn_hidden_size,
+            rnn_dropout=self._rnn_dropout,
             embedding_dim=self._embedding_dim,
         )
 
         self._dynedge_tito = DynEdgeTITO(
-            nb_inputs=self._RNN_hidden_size + 5,
+            nb_inputs=self._rnn_hidden_size + 5,
             dyntrans_layer_sizes=self._dyntrans_layer_sizes,
             features_subset=self._features_subset,
             global_pooling_schemes=self._global_pooling_schemes,
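For callers updating their code against this rename, here is a minimal usage sketch with the new snake_case keyword arguments. It assumes the class defined in RNN_tito.py is named RNN_TITO and that nb_inputs=4 matches the caller's feature set; neither detail appears in the diff above.

# Hypothetical usage sketch; the class name RNN_TITO is assumed from the module path.
from graphnet.models.gnn.RNN_tito import RNN_TITO

model = RNN_TITO(
    nb_inputs=4,         # number of input features per node (assumed value)
    nb_neighbours=8,     # k used for the k-nearest-neighbours clustering
    rnn_layers=2,        # previously RNN_layers
    rnn_hidden_size=64,  # previously RNN_hidden_size
    rnn_dropout=0.5,     # previously RNN_dropout
)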
