diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py index 85ea94d9d..9587010df 100644 --- a/src/graphnet/models/graphs/utils.py +++ b/src/graphnet/models/graphs/utils.py @@ -247,6 +247,7 @@ def __init__(self, x: np.ndarray, cluster_columns: List[int]) -> None: for i in range(len(self._counts)): self._padded_x[i, : self._counts[i]] = x[: self._counts[i]] x = x[self._counts[i] :] + return self.clustered_x def _add_column( self, column: np.ndarray, location: Optional[int] = None @@ -266,6 +267,31 @@ def _add_column( self.clustered_x, location, column, axis=1 ) + def _calculate_charge_sum(self, charge_index: int) -> np.ndarray: + """Calculate the sum of the charge.""" + assert not hasattr( + self, "_charge_sum" + ), "Charge sum has already been calculated, \ + re-calculation is not allowed" + self._charge_sum = self._padded_x[:, :, charge_index].sum(axis=1) + return self._charge_sum + + def _calculate_charge_weights(self, charge_index: int) -> np.ndarray: + """Calculate the weights of the charge.""" + assert not hasattr( + self, "_charge_weights" + ), "Charge weights have already been calculated, \ + re-calculation is not allowed" + assert hasattr( + self, "_charge_sum" + ), "Charge sum has not been calculated, \ + please run calculate_charge_sum" + self._charge_weights = ( + self._padded_x[:, :, charge_index] + / self._charge_sum[:, np.newaxis] + ) + return self._charge_weights + def add_charge_threshold_summary( self, summarization_indices: List[int], @@ -296,21 +322,8 @@ def add_charge_threshold_summary( """ # convert the charge to the cumulative sum of the charge divided # by the total charge - self._charge_weights = self._padded_x[:, :, charge_index] - - self._padded_x[:, :, charge_index] = self._padded_x[ - :, :, charge_index - ].cumsum(axis=1) - - # add the charge sum to the class if it does not already exist - if not hasattr(self, "_charge_sum"): - self._charge_sum = np.nanmax( - self._padded_x[:, :, charge_index], axis=1 - ) - - self._charge_weights = ( - self._charge_weights / self._charge_sum[:, np.newaxis] - ) + self._calculate_charge_sum(charge_index) + self._calculate_charge_weights(charge_index) self._padded_x[:, :, charge_index] = ( self._padded_x[:, :, charge_index] @@ -374,31 +387,6 @@ def add_percentile_summary( self._add_column(percentiles_x, location) return self.clustered_x - def calculate_charge_sum(self, charge_index: int) -> np.ndarray: - """Calculate the sum of the charge.""" - assert not hasattr( - self, "_charge_sum" - ), "Charge sum has already been calculated, \ - re-calculation is not allowed" - self._charge_sum = self._padded_x[:, :, charge_index].sum(axis=1) - return self._charge_sum - - def calculate_charge_weights(self, charge_index: int) -> np.ndarray: - """Calculate the weights of the charge.""" - assert not hasattr( - self, "_charge_weights" - ), "Charge weights have already been calculated, \ - re-calculation is not allowed" - assert hasattr( - self, "_charge_sum" - ), "Charge sum has not been calculated, \ - please run calculate_charge_sum" - self._charge_weights = ( - self._padded_x[:, :, charge_index] - / self._charge_sum[:, np.newaxis] - ) - return self._charge_weights - def add_counts(self, location: Optional[int] = None) -> np.ndarray: """Add the counts of the sensor to the summarization features.""" self._add_column(np.log10(self._counts), location)