Skip to content

Commit

Permalink
move/use internal functions + output x
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Dec 6, 2024
1 parent 505a82e commit 4a4083b
Showing 1 changed file with 28 additions and 40 deletions.
68 changes: 28 additions & 40 deletions src/graphnet/models/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def __init__(self, x: np.ndarray, cluster_columns: List[int]) -> None:
for i in range(len(self._counts)):
self._padded_x[i, : self._counts[i]] = x[: self._counts[i]]
x = x[self._counts[i] :]
return self.clustered_x

def _add_column(
self, column: np.ndarray, location: Optional[int] = None
Expand All @@ -266,6 +267,31 @@ def _add_column(
self.clustered_x, location, column, axis=1
)

def _calculate_charge_sum(self, charge_index: int) -> np.ndarray:
"""Calculate the sum of the charge."""
assert not hasattr(
self, "_charge_sum"
), "Charge sum has already been calculated, \
re-calculation is not allowed"
self._charge_sum = self._padded_x[:, :, charge_index].sum(axis=1)
return self._charge_sum

def _calculate_charge_weights(self, charge_index: int) -> np.ndarray:
"""Calculate the weights of the charge."""
assert not hasattr(
self, "_charge_weights"
), "Charge weights have already been calculated, \
re-calculation is not allowed"
assert hasattr(
self, "_charge_sum"
), "Charge sum has not been calculated, \
please run calculate_charge_sum"
self._charge_weights = (
self._padded_x[:, :, charge_index]
/ self._charge_sum[:, np.newaxis]
)
return self._charge_weights

def add_charge_threshold_summary(
self,
summarization_indices: List[int],
Expand Down Expand Up @@ -296,21 +322,8 @@ def add_charge_threshold_summary(
"""
# convert the charge to the cumulative sum of the charge divided
# by the total charge
self._charge_weights = self._padded_x[:, :, charge_index]

self._padded_x[:, :, charge_index] = self._padded_x[
:, :, charge_index
].cumsum(axis=1)

# add the charge sum to the class if it does not already exist
if not hasattr(self, "_charge_sum"):
self._charge_sum = np.nanmax(
self._padded_x[:, :, charge_index], axis=1
)

self._charge_weights = (
self._charge_weights / self._charge_sum[:, np.newaxis]
)
self._calculate_charge_sum(charge_index)
self._calculate_charge_weights(charge_index)

self._padded_x[:, :, charge_index] = (
self._padded_x[:, :, charge_index]
Expand Down Expand Up @@ -374,31 +387,6 @@ def add_percentile_summary(
self._add_column(percentiles_x, location)
return self.clustered_x

def calculate_charge_sum(self, charge_index: int) -> np.ndarray:
"""Calculate the sum of the charge."""
assert not hasattr(
self, "_charge_sum"
), "Charge sum has already been calculated, \
re-calculation is not allowed"
self._charge_sum = self._padded_x[:, :, charge_index].sum(axis=1)
return self._charge_sum

def calculate_charge_weights(self, charge_index: int) -> np.ndarray:
"""Calculate the weights of the charge."""
assert not hasattr(
self, "_charge_weights"
), "Charge weights have already been calculated, \
re-calculation is not allowed"
assert hasattr(
self, "_charge_sum"
), "Charge sum has not been calculated, \
please run calculate_charge_sum"
self._charge_weights = (
self._padded_x[:, :, charge_index]
/ self._charge_sum[:, np.newaxis]
)
return self._charge_weights

def add_counts(self, location: Optional[int] = None) -> np.ndarray:
"""Add the counts of the sensor to the summarization features."""
self._add_column(np.log10(self._counts), location)
Expand Down

0 comments on commit 4a4083b

Please sign in to comment.