diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index c29867155..48394ab73 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -52,7 +52,7 @@ def __init__( # Member Variables self._detector = detector - self._edge_definiton = edge_definition + self._edge_definition = edge_definition self._node_definition = node_definition if node_feature_names is None: # Assume all features in Detector is used. @@ -113,8 +113,8 @@ def forward( # type: ignore graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32) # Assign edges - if self._edge_definiton is not None: - graph = self._edge_definiton(graph) + if self._edge_definition is not None: + graph = self._edge_definition(graph) else: self.warnonce( "No EdgeDefinition provided. Graphs will not have edges defined!" @@ -154,7 +154,6 @@ def forward( # type: ignore def _validate_input( self, node_features: np.array, node_feature_names: List[str] ) -> None: - # node feature matrix dimension check assert node_features.shape[1] == len(node_feature_names)