From 7f7000fb113c9f43412c528d398f8bca7f89991b Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Fri, 6 Dec 2024 14:29:44 +0900 Subject: [PATCH] fix add_counts optional --- src/graphnet/models/graphs/nodes/nodes.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index 59de864fd..36afb4e1d 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -169,9 +169,7 @@ def _define_output_feature_names( cluster_idx, summ_idx, new_feature_names, - ) = self._get_indices_and_feature_names( - input_feature_names, self._add_counts - ) + ) = self._get_indices_and_feature_names(input_feature_names) self._cluster_indices = cluster_idx self._summarization_indices = summ_idx return new_feature_names @@ -179,7 +177,6 @@ def _define_output_feature_names( def _get_indices_and_feature_names( self, feature_names: List[str], - add_counts: bool, ) -> Tuple[List[int], List[int], List[str]]: cluster_idx, summ_idx, summ_names = identify_indices( feature_names, self._cluster_on @@ -188,7 +185,7 @@ def _get_indices_and_feature_names( for feature in summ_names: for pct in self._percentiles: new_feature_names.append(f"{feature}_pct{pct}") - if add_counts: + if self._add_counts: # add "counts" as the last feature new_feature_names.append("counts") return cluster_idx, summ_idx, new_feature_names @@ -205,7 +202,8 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: summarization_indices=self._summarization_indices, percentiles=self._percentiles, ) - array = cluster_class.add_counts() + if self._add_counts: + array = cluster_class.add_counts() else: self.error( f"""{self.__class__.__name__} was not instatiated with