diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 53c32deaf..e38308ba1 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -413,6 +413,9 @@ def predict_as_dataframe( f"number of output columns ({predictions.shape[1]}) don't match." ) + # Check if predictions are on event- or pulse-level + pulse_level_predictions = len(predictions) > len(dataloader.dataset) + # Get additional attributes attributes: Dict[str, List[np.ndarray]] = OrderedDict( [(attr, []) for attr in additional_attributes] @@ -426,25 +429,39 @@ def predict_as_dataframe( # Check if node level predictions # If true, additional attributes are repeated # to make dimensions fit - if len(predictions) != len(dataloader.dataset): + if pulse_level_predictions: if len(attribute) < np.sum( batch.n_pulses.detach().cpu().numpy() ): attribute = np.repeat( attribute, batch.n_pulses.detach().cpu().numpy() ) - try: - assert len(attribute) == len(batch.x) - except AssertionError: - self.warning_once( - "Could not automatically adjust length" - f"of additional attribute {attr} to match length of" - f"predictions. Make sure {attr} is a graph-level or" - "node-level attribute. Attribute skipped." - ) - pass attributes[attr].extend(attribute) + # Confirm that attributes match length of predictions + skip_attributes = [] + for attr in attributes.keys(): + try: + assert len(attributes[attr]) == len(predictions) + except AssertionError: + self.warning_once( + "Could not automatically adjust length" + f" of additional attribute '{attr}' to match length of" + f" predictions.This error can be caused by heavy" + " disagreement between number of examples in the" + " dataset vs. actual events in the dataloader, e.g. " + " heavy filtering of events in `collate_fn` passed to" + " `dataloader`. This can also be caused by requesting" + " pulse-level attributes for `Task`s that produce" + " event-level predictions. Attribute skipped." + ) + skip_attributes.append(attr) + + # Remove bad attributes + for attr in skip_attributes: + attributes.pop(attr) + additional_attributes.remove(attr) + data = np.concatenate( [predictions] + [