diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py
index 9385bd33b..77669eaeb 100644
--- a/src/graphnet/models/graphs/utils.py
+++ b/src/graphnet/models/graphs/utils.py
@@ -51,7 +51,7 @@ def gather_cluster_sequence(
     Args:
         x: Array for clustering
         feature_idx: Index of the feature in `x` to
-            be gathered for each cluster.
+                     be gathered for each cluster.
         cluster_columns: Index in `x` from which to build clusters.
 
     Returns:
@@ -66,10 +66,16 @@ def gather_cluster_sequence(
         x[:, cluster_columns], return_counts=True, axis=0
     )
     # sort DOMs and pulse-counts
-    sort_this = np.concatenate([unique_sensors, counts.reshape(-1, 1)], axis=1)
-    sort_this = lex_sort(x=sort_this, cluster_columns=cluster_columns)
-    unique_sensors = sort_this[:, 0 : unique_sensors.shape[1]]
-    counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int)
+    sensor_counts = counts.reshape(-1, 1)
+    contingency_table = np.concatenate([unique_sensors, sensor_counts], axis=1)
+    sensors_in_contingency_table = np.arange(0, unique_sensors.shape[1], 1)
+    contingency_table = lex_sort(
+        x=contingency_table, cluster_columns=sensors_in_contingency_table
+    )
+    unique_sensors = contingency_table[:, 0 : unique_sensors.shape[1]]
+    count_part = contingency_table[:, unique_sensors.shape[1] :]
+    flattened_counts = count_part.flatten()
+    counts = flattened_counts.astype(int)
 
     # Pad unique sensor columns with NaN's up until the maximum number of
     # Same pmt-pulses. Each of padded columns represents a pulse.
@@ -129,8 +135,8 @@ def cluster_summarize_with_percentiles(
         then each row in the returned array will correspond to a DOM,
         and the time and charge for each DOM will be summarized by percentiles.
         Returned output array has dimensions
-        `[n_clusters, len(percentiles)*len(summarization_indices)
-        + len(cluster_indices)]`
+        `[n_clusters,
+        len(percentiles)*len(summarization_indices) + len(cluster_indices)]`
 
     Args:
         x: Array to be clustered
@@ -181,9 +187,9 @@ def ice_transparency(
 
     Returns:
         f_scattering: Function that takes a normalized depth and returns the
-            corresponding normalized scattering length.
+                      corresponding normalized scattering length.
         f_absorption: Function that takes a normalized depth and returns the
-            corresponding normalized absorption length.
+                      corresponding normalized absorption length.
     """
     # Data from page 31 of https://arxiv.org/pdf/1301.5361.pdf
    df = pd.read_parquet(
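
Note for reviewers on the `gather_cluster_sequence` hunk: the old call passed `cluster_columns` to `lex_sort`, but those indices refer to columns of the original pulse array `x`, not of the contingency table built from `unique_sensors` and the pulse counts. Whenever `cluster_columns` is anything other than `[0, ..., n-1]`, the old sort keys land on the wrong table columns and can include the count column itself. Below is a minimal sketch of that mismatch; the toy pulse array and the simplified `np.lexsort`-based stand-in for `lex_sort` are invented for illustration only.

```python
import numpy as np

# Simplified stand-in for `lex_sort` in graphnet.models.graphs.utils
# (invented for this demo). Note np.lexsort treats the *last* key in
# the tuple as the primary sort key.
def lex_sort(x: np.ndarray, cluster_columns) -> np.ndarray:
    keys = tuple(x[:, col] for col in cluster_columns)
    return x[np.lexsort(keys), :]

# Toy pulse array with columns [time, dom_x, dom_y]; the sensor
# coordinates deliberately sit in columns 1-2 rather than 0-1.
x = np.array(
    [
        [0.1, 1.0, 5.0],  # sensor A, 1 pulse
        [0.2, 2.0, 4.0],  # sensor B, 3 pulses
        [0.3, 2.0, 4.0],
        [0.4, 2.0, 4.0],
    ]
)
cluster_columns = [1, 2]

unique_sensors, counts = np.unique(
    x[:, cluster_columns], return_counts=True, axis=0
)
contingency_table = np.concatenate(
    [unique_sensors, counts.reshape(-1, 1)], axis=1
)  # table columns: [dom_x, dom_y, count]

# Old behaviour: `cluster_columns` ([1, 2]) index into `x`, so in the
# 3-column contingency table they select (dom_y, count) -- the pulse
# count wrongly becomes the primary sort key.
old_sort = lex_sort(x=contingency_table, cluster_columns=cluster_columns)

# Fixed behaviour: the sensor coordinates always occupy the first
# `unique_sensors.shape[1]` columns of the table, so sort on exactly those.
sensors_in_contingency_table = np.arange(0, unique_sensors.shape[1], 1)
new_sort = lex_sort(
    x=contingency_table, cluster_columns=sensors_in_contingency_table
)

print(old_sort)  # [[1. 5. 1.], [2. 4. 3.]] -- count is the primary key
print(new_sort)  # [[2. 4. 3.], [1. 5. 1.]] -- keys are the sensor columns only
```

The fix pins the sort keys to `np.arange(unique_sensors.shape[1])`, i.e. the table's own sensor columns, so the cluster ordering (and the alignment between `unique_sensors` and `counts`) no longer depends on where the cluster features happen to sit in `x`.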