Skip to content

Commit

Permalink
align with prehooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Nov 15, 2024
1 parent f4c4d16 commit 7fcc7fe
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions src/graphnet/models/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def identify_indices(
return cluster_indices, summarization_indices, features_for_summarization


# TODO Remove this function as it is superseded by the class cluster_and_pad wich has the same functionality
# TODO Remove this function as it is superseded by
# cluster_and_pad wich has the same functionality
def cluster_summarize_with_percentiles(
x: np.ndarray,
summarization_indices: List[int],
Expand Down Expand Up @@ -151,7 +152,9 @@ def cluster_summarize_with_percentiles(
Percentile-summarized array
"""
print(
"This function is deprecated and will be removed, use the class cluster_and_pad with add_percentile_summary instead for the same functionality"
"This function is deprecated and will be removed,",
"use the class cluster_and_pad with add_percentile_summary",
"instead for the same functionality",
)
pct_dict = {}
for feature_idx in summarization_indices:
Expand All @@ -177,7 +180,7 @@ def cluster_summarize_with_percentiles(


class cluster_and_pad:
"""cluster and pad the data for further summarization."""
"""Cluster and pad the data for further summarization."""

def __init__(self, x: np.ndarray, cluster_columns: List[int]) -> None:
"""Initialize the class with the data and cluster columns.
Expand Down Expand Up @@ -251,21 +254,25 @@ def add_charge_threshold_summary(
Args:
summarization_indices: List of column indices that defines features
that will be summarized with percentiles.
that will be summarized with percentiles.
percentiles: percentiles used to summarize `x`. E.g. [10,50,90].
charge_index: index of the charge column in the padded tensor
location: Location to insert the summarization indices in the clustered tensor defaults to adding at the end
location: Location to insert the summarization indices in the
clustered tensor defaults to adding at the end
Returns:
clustered_x: The clustered tensor with the summarization indices added
clustered_x: The clustered tensor with the summarization indices
added
Adds:
_charge_sum: Added to the class
_charge_weights: Added to the class
Altered:
_padded_x: Charge is altered to be the cumulative sum
of the charge divided by the total charge
clustered_x: The summarization indices are added at the end of the tensor
of the charge divided by the total charge
clustered_x: The summarization indices are added at the end
of the tensor
"""
# convert the charge to the cumulative sum of the charge divided by the total charge
# convert the charge to the cumulative sum of the charge divided
# by the total charge
self._charge_weights = self._padded_x[:, :, charge_index]

self._padded_x[:, :, charge_index] = self._padded_x[
Expand Down Expand Up @@ -321,13 +328,15 @@ def add_percentile_summary(
that will be summarized with percentiles.
percentiles: percentiles used to summarize `x`. E.g. [10,50,90].
method: Method to summarize the features. E.g. "linear"
location: Location to insert the summarization indices in the clustered tensor defaults to adding at the end
location: Location to insert the summarization indices in the
clustered tensor defaults to adding at the end
Returns:
None
Adds:
None
Altered:
clustered_x: The summarization indices are added at the end of the tensor
clustered_x: The summarization indices are added at the end of
the tensor
"""
percentiles_x = np.nanpercentile(
self._padded_x[:, :, summarization_indices],
Expand All @@ -346,18 +355,21 @@ def calculate_charge_sum(self, charge_index: int) -> np.ndarray:
"""Calculate the sum of the charge."""
assert not hasattr(
self, "_charge_sum"
), "Charge sum has already been calculated, re-calculation is not allowed"
), "Charge sum has already been calculated, \
re-calculation is not allowed"
self._charge_sum = self._padded_x[:, :, charge_index].sum(axis=1)
return self._charge_sum

def calculate_charge_weights(self, charge_index: int) -> np.ndarray:
"""Calculate the weights of the charge."""
assert not hasattr(
self, "_charge_weights"
), "Charge weights have already been calculated, re-calculation is not allowed"
), "Charge weights have already been calculated, \
re-calculation is not allowed"
assert hasattr(
self, "_charge_sum"
), "Charge sum has not been calculated, please run calculate_charge_sum"
), "Charge sum has not been calculated, \
please run calculate_charge_sum"
self._charge_weights = (
self._padded_x[:, :, charge_index]
/ self._charge_sum[:, np.newaxis]
Expand All @@ -373,7 +385,8 @@ def add_sum_charge(self, location: Optional[int] = None) -> np.ndarray:
"""Add the sum of the charge to the summarization features."""
assert hasattr(
self, "_charge_sum"
), "Charge sum has not been calculated, please run calculate_charge_sum"
), "Charge sum has not been calculated, \
please run calculate_charge_sum"
self._add_column(self._charge_sum, location)
return self.clustered_x

Expand All @@ -386,8 +399,10 @@ def add_std(
"""Add the standard deviation of the column.
Args:
column: Index of the column in the padded tensor to calculate the standard deviation
location: Location to insert the standard deviation in the clustered tensor defaults to adding at the end
column: Index of the column in the padded tensor to
calculate the standard deviation
location: Location to insert the standard deviation in the
clustered tensor defaults to adding at the end
weights: Optional weights to be applied to the standard deviation
"""
self._add_column(
Expand Down

0 comments on commit 7fcc7fe

Please sign in to comment.