From 058ce59bcd2fedb22b27d3d545be82b0d5aa47bd Mon Sep 17 00:00:00 2001 From: Aske-Rosted Date: Thu, 14 Sep 2023 12:50:13 +0900 Subject: [PATCH] Fix #592 --- src/graphnet/models/model.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index 193746919..111a2fa76 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -227,6 +227,7 @@ def predict_as_dataframe( attributes: Dict[str, List[np.ndarray]] = OrderedDict( [(attr, []) for attr in additional_attributes] ) + for batch in dataloader: for attr in attributes: attribute = batch[attr] @@ -236,22 +237,23 @@ def predict_as_dataframe( # Check if node level predictions # If true, additional attributes are repeated # to make dimensions fit - if len(attribute) < np.sum( - batch.n_pulses.detach().cpu().numpy() - ): - attribute = np.repeat( - attribute, batch.n_pulses.detach().cpu().numpy() - ) - try: - assert len(attribute) == len(batch.x) - except AssertionError: - self.warning_once( - "Could not automatically adjust length" - f"of additional attribute {attr} to match length of" - f"predictions. Make sure {attr} is a graph-level or" - "node-level attribute. Attribute skipped." + if len(predictions) != len(dataloader.dataset): + if len(attribute) < np.sum( + batch.n_pulses.detach().cpu().numpy() + ): + attribute = np.repeat( + attribute, batch.n_pulses.detach().cpu().numpy() ) - pass + try: + assert len(attribute) == len(batch.x) + except AssertionError: + self.warning_once( + "Could not automatically adjust length" + f"of additional attribute {attr} to match length of" + f"predictions. Make sure {attr} is a graph-level or" + "node-level attribute. Attribute skipped." + ) + pass attributes[attr].extend(attribute) data = np.concatenate(