From 1758b740d2fe21e501196ab2a08d2e2fed0e9325 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 22 Sep 2023 20:01:00 +0200 Subject: [PATCH] shorten error strings --- .../models/graphs/graph_definition.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 931f5e398..089eee008 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -113,9 +113,12 @@ def forward( # type: ignore node_feature_names: name of each column. Shape ´[,d]´. truth_dicts: Dictionary containing truth labels. custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels. - loss_weight_column: Name of column that holds loss weight. Defaults to None. + loss_weight_column: Name of column that holds loss weight. + Defaults to None. loss_weight: Loss weight associated with event. Defaults to None. - loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None. + loss_weight_default_value: default value for loss weight. + Used in instances where some events have + no pre-defined loss weight. Defaults to None. data_path: Path to dataset data files. Defaults to None. Returns: @@ -146,7 +149,8 @@ def forward( # type: ignore graph = self._edge_definition(graph) else: self.warnonce( - "No EdgeDefinition provided. Graphs will not have edges defined!" + """No EdgeDefinition provided. + Graphs will not have edges defined!""" # noqa ) # Attach data path - useful for Ensemble datasets. @@ -190,11 +194,15 @@ def _validate_input( # was instantiated with. assert len(node_feature_names) == len( self._node_feature_names - ), f"""Input features ({node_feature_names}) is not what {self.__class__.__name__} was instatiated with ({self._node_feature_names})""" + ), f"""Input features ({node_feature_names}) is not what + {self.__class__.__name__} was instatiated + with ({self._node_feature_names})""" # noqa for idx in range(len(node_feature_names)): assert ( node_feature_names[idx] == self._node_feature_names[idx] - ), f""" Order of node features in data are not the same as expected. Got {node_feature_names} vs. {self._node_feature_names}""" + ), f""" Order of node features in data + are not the same as expected. Got {node_feature_names} + vs. {self._node_feature_names}""" # noqa def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: if isinstance(self._perturbation_dict, dict): @@ -298,7 +306,8 @@ def _add_features_individually( graph[feature] = graph.x[:, index].detach() else: self.warnonce( - """Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature.""" + """Cannot assign graph['x']. This field is reserved + for node features. Please rename your input feature.""" # noqa ) return graph