From 75ee094784f159082de6874ae26cc2676265cf63 Mon Sep 17 00:00:00 2001 From: AMHermansen Date: Mon, 11 Sep 2023 15:12:42 +0200 Subject: [PATCH] fix typo in edge definition --- src/graphnet/models/graphs/graph_definition.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index c29867155..48394ab73 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -52,7 +52,7 @@ def __init__( # Member Variables self._detector = detector - self._edge_definiton = edge_definition + self._edge_definition = edge_definition self._node_definition = node_definition if node_feature_names is None: # Assume all features in Detector is used. @@ -113,8 +113,8 @@ def forward( # type: ignore graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32) # Assign edges - if self._edge_definiton is not None: - graph = self._edge_definiton(graph) + if self._edge_definition is not None: + graph = self._edge_definition(graph) else: self.warnonce( "No EdgeDefinition provided. Graphs will not have edges defined!" @@ -154,7 +154,6 @@ def forward( # type: ignore def _validate_input( self, node_features: np.array, node_feature_names: List[str] ) -> None: - # node feature matrix dimension check assert node_features.shape[1] == len(node_feature_names)