From 019bb3f11b26431fa9458b466701ef9cd57ea59b Mon Sep 17 00:00:00 2001
From: "askerosted@gmail.com" <askerosted@gmail.com>
Date: Thu, 29 Aug 2024 15:40:33 +0900
Subject: [PATCH] add DOM summarization

---
 src/graphnet/models/graphs/nodes/nodes.py | 172 +++++++++++++++++++++-
 1 file changed, 165 insertions(+), 7 deletions(-)

diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py
index 4e094e6be..f69b336a5 100644
--- a/src/graphnet/models/graphs/nodes/nodes.py
+++ b/src/graphnet/models/graphs/nodes/nodes.py
@@ -233,6 +233,9 @@ def __init__(
         time_column: str = "dom_time",
         charge_column: str = "charge",
         max_activations: Optional[int] = None,
+        percentiles: List[int] = [],
+        time_sums: List[int] = [],
+        counts_sum_std: List[bool] = [True, True, True],
     ) -> None:
         """Construct `NodeAsDOMTimeSeries`.
 
@@ -242,9 +245,12 @@ def __init__(
             time_column: Name of time column.
             charge_column: Name of charge column.
             max_activations: Maximum number of activations to include in the time series.
+            percentiles: Charge percentiles at which to sample the time and cumulative charge series when summarizing.
+            time_sums: Time offsets after the first pulse at which to sample the time and cumulative charge series when summarizing.
+            counts_sum_std: Booleans indicating whether to include the pulse count, charge-weighted mean time and time standard deviation in the summarization features.
        """
         self._keys = keys
-        super().__init__(input_feature_names=self._keys)
+
         self._id_columns = [self._keys.index(key) for key in id_columns]
         self._time_index = self._keys.index(time_column)
         try:
@@ -259,11 +265,52 @@
             self._charge_index = None
 
         self._max_activations = max_activations
+        self._percentiles = percentiles
+        self._counts_sum_std = counts_sum_std
+        # Ensure that max_activations and percentiles are not both set
+        assert not (
+            max_activations is not None and len(percentiles) > 0  # type: ignore
+        ), "Cannot set both max_activations and percentiles"
+
+        self._time_sums = time_sums
+        super().__init__(input_feature_names=self._keys)
+        self._time_features = [
+            "time" in key for key in self._output_feature_names
+        ]
+        self._charge_features = [
+            "charge" in key for key in self._output_feature_names
+        ]
 
     def _define_output_feature_names(
         self, input_feature_names: List[str]
     ) -> List[str]:
-        return input_feature_names + ["new_node_col"]
+        output_feature_names = deepcopy(input_feature_names)
+        if len(self._percentiles) > 0:
+
+            output_feature_names[self._time_index] = "first_time"
+            if self._charge_index is not None:
+                output_feature_names[self._charge_index] = "sum_charge"
+            output_feature_names += list(
+                np.array(["counts", "mean_time", "std_time"])[
+                    self._counts_sum_std
+                ]
+            )
+            time = []
+            charge = []
+            for pct in self._percentiles:
+                time.append(f"time_pct{pct}")
+                charge.append(f"charge_pct{pct}")
+            output_feature_names += time
+            output_feature_names += charge
+        if len(self._time_sums) > 0:
+            time = []
+            charge = []
+            for time_sum in self._time_sums:
+                time.append(f"time_sum{time_sum}")
+                charge.append(f"charge_sum{time_sum}")
+            output_feature_names += time
+            output_feature_names += charge
+        return output_feature_names
 
     def _construct_nodes(self, x: torch.Tensor) -> Data:
         """Construct nodes from raw node features ´x´."""
@@ -280,11 +327,6 @@
 
         # Sort by time
         x = x[x[:, self._time_index].argsort()]
-        # Undo log10 scaling so we can sum charges
-        x[:, charge_index] = np.power(10, x[:, charge_index])
-        # Shift time to start at 0
-        x[:, self._time_index] -= np.min(x[:, self._time_index])
-
         # Group pulses on the same DOM
         x = lex_sort(x, self._id_columns)
 
         unique_sensors, counts = np.unique(
             x[:, self._id_columns], return_counts=True, axis=0
         )
@@ -296,7 +338,117 @@
         )
         sort_this = lex_sort(x=sort_this, cluster_columns=self._id_columns)
         unique_sensors = sort_this[:, 0 : unique_sensors.shape[1]]
         counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int)
+        counts_sum = counts.sum()
+
+        if len(self._percentiles) > 0:
+            new_x = np.zeros(
+                (
+                    len(counts),
+                    len(self._id_columns) + 2 + sum(self._counts_sum_std),
+                )
+            )
+            new_x[:, self._id_columns] = unique_sensors
+            split_x = torch.split(torch.tensor(x), counts.tolist())
+            split_x = torch.nn.utils.rnn.pad_sequence(
+                split_x, batch_first=True, padding_value=np.nan
+            ).numpy()
+            weights = split_x[:, :, charge_index] / np.nansum(
+                split_x[:, :, charge_index], axis=1
+            ).reshape(-1, 1)
+            split_x[:, :, charge_index] = split_x[:, :, charge_index].cumsum(
+                axis=1
+            )
+            new_x[:, charge_index] = np.nanmax(
+                split_x[:, :, charge_index], axis=1
+            )
+            split_x[:, :, charge_index] = split_x[:, :, charge_index] / new_x[
+                :, charge_index
+            ].reshape(-1, 1)
+            new_x[:, charge_index] = np.log10(new_x[:, charge_index])
+            new_x[:, self._time_index] = np.nanmin(
+                split_x[:, :, self._time_index], axis=1
+            )
+            if self._counts_sum_std[0]:
+                new_x[:, -sum(self._counts_sum_std)] = np.log10(counts)
+            if self._counts_sum_std[1]:
+                new_x[:, -sum(self._counts_sum_std[1:])] = np.nansum(
+                    split_x[:, :, self._time_index] * weights, axis=1
+                )
+            if self._counts_sum_std[2]:
+                new_x[:, -1] = np.nanstd(
+                    split_x[:, :, self._time_index] * weights, axis=1
+                )
+            # Sample times and charges at the requested cumulative-charge percentiles
+            selections = np.argmax(
+                split_x[:, :, charge_index][:, :, np.newaxis]
+                >= (np.array(self._percentiles) / 100),
+                axis=1,
+            )
+            selections += (np.arange(len(counts)) * split_x.shape[1])[
+                :, np.newaxis
+            ]
+            new_x = np.column_stack(
+                [new_x, split_x[:, :, self._time_index].flatten()[selections]]
+            )
+            new_x = np.column_stack(
+                [new_x, split_x[:, :, charge_index].flatten()[selections]]
+            )
+            if len(self._time_sums) > 0:
+                # Sample the time and cumulative charge at fixed offsets after the first pulse
+                selections = np.argmax(
+                    (
+                        split_x[:, :, self._time_index]
+                        - split_x[:, 0, self._time_index][:, np.newaxis]
+                    )[:, :, np.newaxis]
+                    >= self._time_sums,
+                    axis=1,
+                )
+                selections += (np.arange(len(counts)) * split_x.shape[1])[
+                    :, np.newaxis
+                ]
+                new_x = np.column_stack(
+                    [
+                        new_x,
+                        split_x[:, :, self._time_index].flatten()[selections],
+                    ]
+                )
+                new_x = np.column_stack(
+                    [new_x, split_x[:, :, charge_index].flatten()[selections]]
+                )
+
+            new_x[:, self._time_features] = (
+                new_x[:, self._time_features] / 3.0e4
+            )
+            return Data(x=torch.tensor(new_x))
+
+        if self._max_activations is not None:
+            counts_mask = np.argwhere(
+                counts >= self._max_activations
+            ).flatten()
+            if counts_mask.size > 0:
+                counts_vals = counts[counts_mask]
+                indices = np.concatenate(
+                    [
+                        [ind] * repeats
+                        for ind, repeats in zip(
+                            counts_mask,
+                            np.floor(
+                                counts_vals / self._max_activations
+                            ).astype(int),
+                        )
+                    ]
+                )
+                counts = np.insert(
+                    counts % self._max_activations,
+                    indices,
+                    self._max_activations,
+                )
+                counts = counts[counts != 0]
+        assert (
+            counts.sum() == counts_sum
+        ), f"Counts sum changed from {counts_sum} to {counts.sum()}"
 
         new_node_col = np.zeros(x.shape[0])
         new_node_col[counts.cumsum()[:-1]] = 1
@@ -305,6 +457,12 @@
 
         return Data(x=torch.tensor(x))
 
+    def _rename_features(self, feature_names: List[str]) -> List[str]:
+        """Rename features to include time series data."""
+        new_feature_names = deepcopy(feature_names)
+
+        return new_feature_names
+
 
 class IceMixNodes(NodeDefinition):
     """Calculate ice properties and perform random sampling.
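
--
Usage sketch (illustrative, not part of the diff): the new arguments switch
`NodeAsDOMTimeSeries` from emitting raw per-DOM activations to emitting one
summarized row per DOM. The import path and feature names below are assumed
for illustration and may differ in a given dataset:

    from graphnet.models.graphs.nodes import NodeAsDOMTimeSeries

    node_definition = NodeAsDOMTimeSeries(
        keys=["dom_x", "dom_y", "dom_z", "dom_time", "charge"],  # assumed feature order
        id_columns=["dom_x", "dom_y", "dom_z"],  # columns that identify a DOM
        time_column="dom_time",
        charge_column="charge",
        percentiles=[10, 50, 90],  # sample where cumulative charge crosses 10/50/90%
        time_sums=[100, 500],  # sample at these time offsets after the first pulse
        counts_sum_std=[True, True, True],  # include counts, mean_time, std_time
    )
    # Note: `percentiles` and `max_activations` are mutually exclusive
    # (asserted in `__init__`).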
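The heart of the summarization is percentile sampling via `np.argmax` over the
normalized cumulative charge. A self-contained toy version of that step, for a
single DOM with made-up pulses (it mirrors, but does not reproduce, the padded
batch version in `_construct_nodes`):

    import numpy as np

    times = np.array([10.0, 12.0, 20.0, 55.0, 90.0])  # pulse times, sorted
    charges = np.array([0.5, 1.0, 2.0, 0.5, 1.0])  # pulse charges

    cum = charges.cumsum() / charges.sum()  # cumulative charge fraction in [0, 1]
    pcts = np.array([10, 50, 90]) / 100
    idx = np.argmax(cum[:, None] >= pcts, axis=0)  # first pulse reaching each percentile
    print(times[idx], cum[idx])  # -> [10. 20. 90.] [0.1 0.7 1. ]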