Skip to content

Commit

Permalink
Merge pull request #593 from Aske-Rosted/main
Browse files Browse the repository at this point in the history
Fix #592
  • Loading branch information
Aske-Rosted authored Sep 14, 2023
2 parents c46b8ce + 058ce59 commit 6edf835
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions src/graphnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def predict_as_dataframe(
attributes: Dict[str, List[np.ndarray]] = OrderedDict(
[(attr, []) for attr in additional_attributes]
)

for batch in dataloader:
for attr in attributes:
attribute = batch[attr]
Expand All @@ -236,22 +237,23 @@ def predict_as_dataframe(
# Check if node level predictions
# If true, additional attributes are repeated
# to make dimensions fit
if len(attribute) < np.sum(
batch.n_pulses.detach().cpu().numpy()
):
attribute = np.repeat(
attribute, batch.n_pulses.detach().cpu().numpy()
)
try:
assert len(attribute) == len(batch.x)
except AssertionError:
self.warning_once(
"Could not automatically adjust length"
f"of additional attribute {attr} to match length of"
f"predictions. Make sure {attr} is a graph-level or"
"node-level attribute. Attribute skipped."
if len(predictions) != len(dataloader.dataset):
if len(attribute) < np.sum(
batch.n_pulses.detach().cpu().numpy()
):
attribute = np.repeat(
attribute, batch.n_pulses.detach().cpu().numpy()
)
pass
try:
assert len(attribute) == len(batch.x)
except AssertionError:
self.warning_once(
"Could not automatically adjust length"
f"of additional attribute {attr} to match length of"
f"predictions. Make sure {attr} is a graph-level or"
"node-level attribute. Attribute skipped."
)
pass
attributes[attr].extend(attribute)

data = np.concatenate(
Expand Down

0 comments on commit 6edf835

Please sign in to comment.