Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Feb 4, 2024
1 parent efdadb1 commit db2fe8f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/04_training/05_train_RNN_TITO.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def main(
backbone = RNN_TITO(
nb_inputs=graph_definition.nb_outputs,
nb_neighbours=8,
RNN_layers=2,
RNN_hidden_size=64,
RNN_dropout=0.5,
rnn_layers=2,
rnn_hidden_size=64,
rnn_dropout=0.5,
features_subset=[0, 1, 2, 3],
dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256), (256, 256)],
post_processing_layer_sizes=[336, 256],
Expand Down
4 changes: 2 additions & 2 deletions src/graphnet/models/rnn/node_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def forward(self, data: Data) -> torch.Tensor:
)
rnn_out = self._rnn(time_series)[-1][0]
# prepare node level features
charge = data.charge.tensor_split(splitter)
charge = data.x[:, data.time_series_index[0][0]].tensor_split(splitter)
charge = torch.tensor(
[
torch.asinh(5 * torch.sum(node_charges) / 5)
Expand All @@ -114,7 +114,7 @@ def forward(self, data: Data) -> torch.Tensor:
)
batch = data.batch[x[:, -1].bool()]
x = x[x[:, -1].bool()][:, :-1]
x[:, data.features[0].index("charge")] = charge
x[:, data.time_series_index[0][0]] = charge

# combine the RNN output with the DOM summary features
data.x = torch.hstack([x, rnn_out])
Expand Down

0 comments on commit db2fe8f

Please sign in to comment.