Skip to content

Commit

Permalink
Merge pull request graphnet-team#658 from RasmusOrsoe/predict_as_data…
Browse files Browse the repository at this point in the history
…frame_bugfix

`predict_as_dataframe` bugfix + syntax error in `minkowski.py`
  • Loading branch information
RasmusOrsoe authored Mar 20, 2024
2 parents 6f7db39 + 7c4fdfb commit 36efb24
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/graphnet/models/standard_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ def predict_as_dataframe(
f"number of output columns ({predictions.shape[1]}) don't match."
)

# Check if predictions are on event- or pulse-level
pulse_level_predictions = len(predictions) > len(dataloader.dataset)

# Get additional attributes
attributes: Dict[str, List[np.ndarray]] = OrderedDict(
[(attr, []) for attr in additional_attributes]
Expand All @@ -426,25 +429,39 @@ def predict_as_dataframe(
# Check if node level predictions
# If true, additional attributes are repeated
# to make dimensions fit
if len(predictions) != len(dataloader.dataset):
if pulse_level_predictions:
if len(attribute) < np.sum(
batch.n_pulses.detach().cpu().numpy()
):
attribute = np.repeat(
attribute, batch.n_pulses.detach().cpu().numpy()
)
try:
assert len(attribute) == len(batch.x)
except AssertionError:
self.warning_once(
"Could not automatically adjust length"
f"of additional attribute {attr} to match length of"
f"predictions. Make sure {attr} is a graph-level or"
"node-level attribute. Attribute skipped."
)
pass
attributes[attr].extend(attribute)

# Confirm that attributes match length of predictions
skip_attributes = []
for attr in attributes.keys():
try:
assert len(attributes[attr]) == len(predictions)
except AssertionError:
self.warning_once(
"Could not automatically adjust length"
f" of additional attribute '{attr}' to match length of"
f" predictions.This error can be caused by heavy"
" disagreement between number of examples in the"
" dataset vs. actual events in the dataloader, e.g. "
" heavy filtering of events in `collate_fn` passed to"
" `dataloader`. This can also be caused by requesting"
" pulse-level attributes for `Task`s that produce"
" event-level predictions. Attribute skipped."
)
skip_attributes.append(attr)

# Remove bad attributes
for attr in skip_attributes:
attributes.pop(attr)
additional_attributes.remove(attr)

data = np.concatenate(
[predictions]
+ [
Expand Down

0 comments on commit 36efb24

Please sign in to comment.