Skip to content

Commit

Permalink
shorten error strings
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusOrsoe committed Sep 22, 2023
1 parent 822c0d7 commit 1758b74
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/graphnet/models/graphs/graph_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,12 @@ def forward( # type: ignore
node_feature_names: name of each column. Shape ´[,d]´.
truth_dicts: Dictionary containing truth labels.
custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
loss_weight_column: Name of column that holds loss weight. Defaults to None.
loss_weight_column: Name of column that holds loss weight.
Defaults to None.
loss_weight: Loss weight associated with event. Defaults to None.
loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.
loss_weight_default_value: default value for loss weight.
Used in instances where some events have
no pre-defined loss weight. Defaults to None.
data_path: Path to dataset data files. Defaults to None.
Returns:
Expand Down Expand Up @@ -146,7 +149,8 @@ def forward( # type: ignore
graph = self._edge_definition(graph)
else:
self.warnonce(
"No EdgeDefinition provided. Graphs will not have edges defined!"
"""No EdgeDefinition provided.
Graphs will not have edges defined!""" # noqa
)

# Attach data path - useful for Ensemble datasets.
Expand Down Expand Up @@ -190,11 +194,15 @@ def _validate_input(
# was instantiated with.
assert len(node_feature_names) == len(
self._node_feature_names
), f"""Input features ({node_feature_names}) is not what {self.__class__.__name__} was instatiated with ({self._node_feature_names})"""
), f"""Input features ({node_feature_names}) is not what
{self.__class__.__name__} was instatiated
with ({self._node_feature_names})""" # noqa
for idx in range(len(node_feature_names)):
assert (
node_feature_names[idx] == self._node_feature_names[idx]
), f""" Order of node features in data are not the same as expected. Got {node_feature_names} vs. {self._node_feature_names}"""
), f""" Order of node features in data
are not the same as expected. Got {node_feature_names}
vs. {self._node_feature_names}""" # noqa

def _perturb_input(self, node_features: np.ndarray) -> np.ndarray:
if isinstance(self._perturbation_dict, dict):
Expand Down Expand Up @@ -298,7 +306,8 @@ def _add_features_individually(
graph[feature] = graph.x[:, index].detach()
else:
self.warnonce(
"""Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature."""
"""Cannot assign graph['x']. This field is reserved
for node features. Please rename your input feature.""" # noqa
)
return graph

Expand Down

0 comments on commit 1758b74

Please sign in to comment.