From 873304db21860e08e26f3b0877962eb8c6edd0ba Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Fri, 8 Sep 2023 09:24:11 +0200 Subject: [PATCH] automatic attribute length adjustment --- src/graphnet/models/model.py | 29 ++++++++++++++++++++------- src/graphnet/models/standard_model.py | 4 ---- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index 00acf9109..193746919 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -181,9 +181,7 @@ def predict_as_dataframe( dataloader: DataLoader, prediction_columns: List[str], *, - node_level: bool = False, additional_attributes: Optional[List[str]] = None, - index_column: str = "event_no", gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", ) -> pd.DataFrame: @@ -231,12 +229,29 @@ def predict_as_dataframe( ) for batch in dataloader: for attr in attributes: - attribute = batch[attr].detach().cpu().numpy() - if node_level: - if attr == index_column: - attribute = np.repeat( - attribute, batch.n_pulses.detach().cpu().numpy() + attribute = batch[attr] + if isinstance(attribute, torch.Tensor): + attribute = attribute.detach().cpu().numpy() + + # Check if node level predictions + # If true, additional attributes are repeated + # to make dimensions fit + if len(attribute) < np.sum( + batch.n_pulses.detach().cpu().numpy() + ): + attribute = np.repeat( + attribute, batch.n_pulses.detach().cpu().numpy() + ) + try: + assert len(attribute) == len(batch.x) + except AssertionError: + self.warning_once( + "Could not automatically adjust length" + f"of additional attribute {attr} to match length of" + f"predictions. Make sure {attr} is a graph-level or" + "node-level attribute. Attribute skipped." ) + pass attributes[attr].extend(attribute) data = np.concatenate( diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 04b60628f..1d439133f 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -191,9 +191,7 @@ def predict_as_dataframe( dataloader: DataLoader, prediction_columns: Optional[List[str]] = None, *, - node_level: bool = False, additional_attributes: Optional[List[str]] = None, - index_column: str = "event_no", gpus: Optional[Union[List[int], int]] = None, distribution_strategy: Optional[str] = "auto", ) -> pd.DataFrame: @@ -207,9 +205,7 @@ def predict_as_dataframe( return super().predict_as_dataframe( dataloader=dataloader, prediction_columns=prediction_columns, - node_level=node_level, additional_attributes=additional_attributes, - index_column=index_column, gpus=gpus, distribution_strategy=distribution_strategy, )