From e6357b50a472d491b4652357676f1622f5068c85 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Thu, 24 Oct 2024 17:01:42 +0900 Subject: [PATCH] Update PercentileCluster --- src/graphnet/models/graphs/nodes/nodes.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index 558ec96f4..59de864fd 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -9,7 +9,7 @@ from graphnet.utilities.decorators import final from graphnet.models import Model from graphnet.models.graphs.utils import ( - cluster_summarize_with_percentiles, + cluster_and_pad, identify_indices, lex_sort, ice_transparency, @@ -198,13 +198,14 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: x = x.numpy() # Construct clusters with percentile-summarized features if hasattr(self, "_summarization_indices"): - array = cluster_summarize_with_percentiles( - x=x, + cluster_class = cluster_and_pad( + x=x, cluster_columns=self._cluster_indices + ) + array = cluster_class.add_percentile_summary( summarization_indices=self._summarization_indices, - cluster_indices=self._cluster_indices, percentiles=self._percentiles, - add_counts=self._add_counts, ) + array = cluster_class.add_counts() else: self.error( f"""{self.__class__.__name__} was not instatiated with