diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index 4e094e6be..139d851a0 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -9,7 +9,7 @@ from graphnet.utilities.decorators import final from graphnet.models import Model from graphnet.models.graphs.utils import ( - cluster_summarize_with_percentiles, + cluster_and_pad, identify_indices, lex_sort, ice_transparency, @@ -198,13 +198,14 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: x = x.numpy() # Construct clusters with percentile-summarized features if hasattr(self, "_summarization_indices"): - array = cluster_summarize_with_percentiles( - x=x, + cluster_class = cluster_and_pad( + x=x, cluster_columns=self._cluster_indices + ) + array = cluster_class.add_percentile_summary( summarization_indices=self._summarization_indices, - cluster_indices=self._cluster_indices, percentiles=self._percentiles, - add_counts=self._add_counts, ) + array = cluster_class.add_counts() else: self.error( f"""{self.__class__.__name__} was not instatiated with