diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py
index 9b9fc61b0..2802f48b3 100644
--- a/src/graphnet/models/detector/detector.py
+++ b/src/graphnet/models/detector/detector.py
@@ -33,7 +33,7 @@ def forward(  # type: ignore
     @property
     def geometry_table(self) -> pd.DataFrame:
         """Public get method for retrieving a `Detector`s geometry table."""
-        if ~hasattr(self, "_geometry_table"):
+        if not hasattr(self, "_geometry_table"):
             try:
                 assert hasattr(self, "geometry_table_path")
             except AssertionError as e:
@@ -60,6 +60,16 @@ def sensor_index_name(self) -> str:
         """Public get method for retrieving the sensor id column name."""
         return self.sensor_id_column
 
+    @property
+    def sensor_time_name(self) -> str:
+        """Public get method for retrieving the sensor time column name."""
+        return self.sensor_time_column
+
+    @property
+    def charge_name(self) -> str:
+        """Public get method for retrieving the charge column name."""
+        return self.charge_column
+
     @final
     def _standardize(
         self, input_features: torch.tensor, input_feature_names: List[str]
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 877a0bf65..4fc6fa2d7 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -17,6 +17,8 @@ class IceCube86(Detector):
     xyz = ["dom_x", "dom_y", "dom_z"]
     string_id_column = "string"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "dom_time"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension of input data."""
@@ -57,6 +59,8 @@ class IceCubeKaggle(Detector):
     xyz = ["x", "y", "z"]
     string_id_column = "string"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "time"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension of input data."""
@@ -122,6 +126,8 @@ class IceCubeUpgrade(Detector):
     xyz = ["dom_x", "dom_y", "dom_z"]
     string_id_column = "string"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "dom_time"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension of input data."""
diff --git a/src/graphnet/models/detector/liquido.py b/src/graphnet/models/detector/liquido.py
index a993b2344..de2c34cc1 100644
--- a/src/graphnet/models/detector/liquido.py
+++ b/src/graphnet/models/detector/liquido.py
@@ -17,6 +17,8 @@ class LiquidO_v1(Detector):
     xyz = ["sipm_x", "sipm_y", "sipm_z"]
     string_id_column = "fiber_id"
     sensor_id_column = "sipm_id"
+    sensor_time_column = "t"
+    charge_column = None
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py
index 90a0fbd29..a0aa5032f 100644
--- a/src/graphnet/models/detector/prometheus.py
+++ b/src/graphnet/models/detector/prometheus.py
@@ -17,6 +17,8 @@ class ORCA150SuperDense(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -47,6 +49,8 @@ class TRIDENT1211(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -77,6 +81,8 @@ class IceCubeUpgrade7(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -107,6 +113,8 @@ class WaterDemo81(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -137,6 +145,8 @@ class BaikalGVD8(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -167,6 +177,8 @@ class IceDemo81(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -197,6 +209,8 @@ class ARCA115(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -227,6 +241,8 @@ class ORCA150(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -257,6 +273,8 @@ class IceCube86Prometheus(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -287,6 +305,8 @@ class IceCubeDeepCore8(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -317,6 +337,8 @@ class IceCubeGen2(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -344,6 +366,8 @@ class PONETriangle(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
+    charge_column = "charge"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
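
The new `sensor_time_column`/`charge_column` class attributes surface through the `sensor_time_name`/`charge_name` properties added to the `Detector` base class above. A minimal sketch of how downstream code can look the names up (illustrative, assuming a standard graphnet installation):

from graphnet.models.detector.icecube import IceCube86

detector = IceCube86()
print(detector.sensor_time_name)  # "dom_time"
print(detector.charge_name)  # "charge"
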
diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
index 0338225b8..313112b4e 100644
--- a/src/graphnet/models/graphs/graph_definition.py
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -34,6 +34,8 @@ def __init__(
         sensor_mask: Optional[List[int]] = None,
         string_mask: Optional[List[int]] = None,
         sort_by: str = None,
+        merge_coincident: bool = False,
+        merge_window: Optional[float] = None,
         repeat_labels: bool = False,
     ):
         """Construct ´GraphDefinition´. The ´detector´ holds.
@@ -64,7 +66,12 @@
                 to the graph with padded pulse information. Defaults to False.
             sensor_mask: A list of sensor id's to be masked from the graph.
                 Any sensor listed here will be removed from the graph.
-                Defaults to None.
+                Defaults to None.
+            merge_coincident: If True, raw pulses/photons arriving on the same
+                PMT within `merge_window` ns will be merged into a single pulse.
+            merge_window: The size of the time window (in ns) used to merge
+                coincident pulses/photons. Has no effect if `merge_coincident`
+                is `False`. Defaults to None.
             string_mask: A list of string id's to be masked from the graph.
                 Defaults to None.
             sort_by: Name of node feature to sort by. Defaults to None.
@@ -86,8 +93,11 @@
         self._sensor_mask = sensor_mask
         self._string_mask = string_mask
         self._add_inactive_sensors = add_inactive_sensors
+        self._n_modules = self._detector.geometry_table.shape[0]
+        self._merge_window = merge_window
+        self._merge = merge_coincident
+        self._charge_key = self._detector.charge_name
         self._repeat_labels = repeat_labels
-
         self._resolve_masks()
 
         if self._edge_definition is None:
@@ -145,6 +155,18 @@
         else:
             self.rng = default_rng()
 
+        if merge_coincident:
+            if merge_window is None:
+                raise AssertionError(
+                    f"Got `merge_coincident`={merge_coincident}, "
+                    "but `merge_window` = `None`."
+                    " Please specify a value."
+                )
+            elif merge_window <= 0:
+                raise AssertionError(
+                    f"`merge_window` must be > 0. Got {merge_window}."
+                )
+
     def forward(  # type: ignore
         self,
         input_features: np.ndarray,
@@ -159,10 +181,11 @@
         """Construct graph as ´Data´ object.
 
         Args:
-            input_features: Input features for graph construction. Shape ´[num_rows, d]´
+            input_features: Input features for graph construction.
+                Shape ´[num_rows, d]´
             input_feature_names: name of each column. Shape ´[,d]´.
             truth_dicts: Dictionary containing truth labels.
-            custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
+            custom_label_functions: Custom label functions.
             loss_weight_column: Name of column that holds loss weight.
                 Defaults to None.
             loss_weight: Loss weight associated with event. Defaults to None.
@@ -195,6 +218,12 @@
         # Gaussian perturbation of each column if perturbation dict is given
         input_features = self._perturb_input(input_features)
 
+        # Merge coincident pulses
+        if self._merge:
+            input_features = self._merge_into_pulses(
+                input_features=input_features
+            )
+
         # Transform to pytorch tensor
         input_features = torch.tensor(input_features, dtype=self.dtype)
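
With the constructor and `forward` changes above, merging is switched on entirely from the graph definition. A minimal usage sketch mirroring the new arguments (the detector and feature names are illustrative, matching the test added at the bottom of this diff):

from graphnet.models.graphs import KNNGraph
from graphnet.models.detector.prometheus import ORCA150SuperDense

graph_definition = KNNGraph(
    detector=ORCA150SuperDense(),
    input_feature_names=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"],
    merge_coincident=True,
    merge_window=2.0,  # ns; photons on the same PMT within 2 ns get merged
)
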
@@ -251,9 +280,10 @@ def _resolve_masks(self) -> None:
         """Handle cases with sensor/string masks."""
         if self._sensor_mask is not None:
             if self._string_mask is not None:
-                assert (
-                    1 == 2
-                ), """Got arguments for both `sensor_mask`and `string_mask`. Please specify only one. """
+                raise AssertionError(
+                    "Got arguments for both `sensor_mask` and "
+                    "`string_mask`. Please specify only one."
+                )
 
         if (self._sensor_mask is None) & (self._string_mask is not None):
             self._sensor_mask = self._convert_string_to_sensor_mask()
@@ -321,8 +351,11 @@ def _geometry_table_lookup(
         return self._detector.geometry_table.loc[idx, :].index
 
     def _validate_input(
-        self, input_features: np.array, input_feature_names: List[str]
+        self,
+        input_features: np.array,
+        input_feature_names: List[str],
     ) -> None:
+        # Check that the node feature matrix matches the feature names
        assert input_features.shape[1] == len(input_feature_names)
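
The hunk below adds the merging helpers. The central reduction in `_merge_to_pulses` is a charge-weighted mean for every attribute except charge, which is summed. A small self-contained illustration of that arithmetic (toy numbers, not taken from the code):

import numpy as np

times = np.array([0.1, 2.0])  # two photons on the same PMT
charges = np.array([1.0, 3.0])
weights = charges / charges.sum()  # [0.25, 0.75]
merged_time = (times * weights).sum()  # 1.525, the charge-weighted mean
merged_charge = charges.sum()  # 4.0, charges add up
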
@@ -464,3 +497,153 @@ def _add_custom_labels(
                 label = label.repeat(graph.x.shape[0], 1)
             graph[key] = label
         return graph
+
+    def _merge_into_pulses(
+        self, input_features: np.ndarray
+    ) -> np.ndarray:
+        """Merge photon attributes into pulses and add pseudo-charge."""
+        photons = {}
+        for key in self._input_feature_names:
+            photons[key] = input_features[
+                :, self._input_feature_names.index(key)
+            ].tolist()
+
+        # Create temporary module ids based on xyz coordinates
+        xyz = self._detector.xyz
+        ids = self._assign_temp_ids(
+            x=photons[xyz[0]],
+            y=photons[xyz[1]],
+            z=photons[xyz[2]],
+        )
+
+        # Identify photons that need to be merged
+        assert isinstance(self._merge_window, (float, int))
+        idx = self._find_photons_for_merging(
+            t=photons[self._detector.sensor_time_name],
+            ids=ids,
+            merge_window=self._merge_window,
+        )
+
+        # Merge photon attributes based on temporary ids
+        pulses = self._merge_to_pulses(data_dict=photons, ids_to_merge=idx)
+
+        # Delete photons that were merged
+        delete_these = []
+        for group in idx:
+            delete_these.extend(group)
+
+        if len(delete_these) > 0:
+            for key in photons.keys():
+                photons[key] = np.delete(
+                    np.array(photons[key]), delete_these
+                ).tolist()
+
+        # Add the merged pulses instead
+        for key in photons.keys():
+            photons[key].extend(pulses[key])
+        del pulses  # save memory
+
+        input_features = np.concatenate(
+            [
+                np.array(photons[key]).reshape(-1, 1)
+                for key in self._input_feature_names
+            ],
+            axis=1,
+        )
+
+        return input_features
+
+    def _merge_to_pulses(
+        self, data_dict: Dict[str, List], ids_to_merge: List[List[int]]
+    ) -> Dict[str, List]:
+        """Merge photon attributes into pulses according to assigned ids."""
+        # Initialize a new dictionary to store the merged results
+        merged_dict: Dict[str, List] = {key: [] for key in data_dict.keys()}
+
+        # Iterate over the groups of IDs to merge
+        for group in ids_to_merge:
+            for key in data_dict.keys():
+                # Extract the values corresponding to the current group of IDs
+                values_to_merge = np.array([data_dict[key][i] for i in group])
+                if self._charge_key in data_dict.keys():
+                    charges = np.array(
+                        [data_dict[self._charge_key][i] for i in group]
+                    )
+                    weights = charges / sum(charges)
+                else:
+                    self.warning_once(
+                        f"`{self._charge_key}` not available in"
+                        f" {self._input_feature_names}."
+                        " Cannot do weighted sum."
+                    )
+                    # Fall back to uniform weights, i.e. a plain mean
+                    weights = np.ones(len(group)) / len(group)
+                # Handle numeric and non-numeric fields differently
+                if all(
+                    isinstance(value, (int, float))
+                    for value in values_to_merge
+                ):
+                    # Calculate the weighted mean for all attributes
+                    # except charge, which is summed
+                    if key != self._charge_key:
+                        merged_value = sum(values_to_merge * weights)
+                    else:
+                        merged_value = sum(charges)
+                else:
+                    raise ValueError(
+                        f"Cannot merge non-numeric field `{key}`."
+                    )
+                merged_dict[key].append(merged_value)
+
+        return merged_dict
+
+    def _assign_temp_ids(
+        self, x: List[float], y: List[float], z: List[float]
+    ) -> List[int]:
+        """Create a temporary module id based on xyz positions."""
+        # Convert lists to a structured NumPy array
+        data = np.array(
+            list(zip(x, y, z)),
+            dtype=[("x", float), ("y", float), ("z", float)],
+        )
+
+        # Get the unique rows and the indices to reconstruct
+        # the original array with IDs
+        _, ids = np.unique(data, return_inverse=True, axis=0)
+
+        return ids.tolist()
+
+    def _find_photons_for_merging(
+        self, t: List[float], ids: List[int], merge_window: float
+    ) -> List[List[int]]:
+        """Identify photons that need to be merged."""
+        # Convert lists to a structured NumPy array
+        data = np.array(
+            list(zip(t, ids)), dtype=[("time", float), ("id", int)]
+        )
+
+        # Get original indices after sorting by ID first and then by time
+        sorted_indices = np.argsort(data, order=["id", "time"])
+        sorted_data = data[sorted_indices]
+
+        close_elements_indices = []
+        current_group = [sorted_indices[0]]
+
+        for i in range(1, len(sorted_data)):
+            current_value = sorted_data[i]["time"]
+            current_id_value = sorted_data[i]["id"]
+
+            # Compare with the last element in the current group
+            if (
+                current_id_value == sorted_data[i - 1]["id"]
+                and current_value - sorted_data[i - 1]["time"] < merge_window
+            ):
+                current_group.append(sorted_indices[i])
+            else:
+                # If the group has more than one element,
+                # add it to the results
+                if len(current_group) > 1:
+                    close_elements_indices.append(current_group)
+                # Start a new group
+                current_group = [sorted_indices[i]]
+
+        # Append the last group if it has more than one element
+        if len(current_group) > 1:
+            close_elements_indices.append(current_group)
+
+        return close_elements_indices
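
To make `_find_photons_for_merging` concrete, here is a self-contained sketch of the same sort-and-scan idea applied to the five photons used by the test below (a toy re-implementation, not an import from graphnet):

import numpy as np

t = [0.1, 2.0, 0.1, 0.1, 7.0]
ids = [1, 1, 2, 0, 0]  # temporary module ids derived from xyz positions
data = np.array(list(zip(t, ids)), dtype=[("time", float), ("id", int)])
order = np.argsort(data, order=["id", "time"]).tolist()
groups, current = [], [order[0]]
for i in range(1, len(order)):
    prev, this = data[order[i - 1]], data[order[i]]
    if this["id"] == prev["id"] and this["time"] - prev["time"] < 2.0:
        current.append(order[i])
    else:
        if len(current) > 1:
            groups.append(current)
        current = [order[i]]
if len(current) > 1:
    groups.append(current)
print(groups)  # [[0, 1]]: with a 2 ns window, only photons 0 and 1 merge
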
diff --git a/tests/models/test_graph_definition.py b/tests/models/test_graph_definition.py
index 091cc0bac..00a17ecc5 100644
--- a/tests/models/test_graph_definition.py
+++ b/tests/models/test_graph_definition.py
@@ -17,6 +17,42 @@
 from typing import List
 
 
+def test_merging_coincident_pulses() -> None:
+    """Test merging of coincident pulses/photons."""
+    input_feature_names = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]
+    input_features = np.array(
+        [
+            [1, 2, 3, 0.1],
+            [1, 2, 3, 2],
+            [2, 2, 1, 0.1],
+            [1, 2, 1, 0.1],
+            [1, 2, 1, 7],
+        ]
+    )
+
+    pulses = []
+    for window in [0.5, 2, 8]:
+        graph_definition = KNNGraph(
+            detector=ORCA150SuperDense(),
+            merge_coincident=True,
+            merge_window=window,
+            input_feature_names=input_feature_names,
+        )
+        y = graph_definition(
+            input_features=input_features,
+            input_feature_names=input_feature_names,
+        )
+        pulses.append(y.x.shape[0])
+
+    # Merging window (0.5 ns) is too small to merge anything,
+    # so nothing should happen!
+    assert pulses[0] == input_features.shape[0]
+    # Merging window (2 ns) is large enough to merge two of the photons
+    # into a single pulse
+    assert pulses[1] == (input_features.shape[0] - 1)
+    # Merging window (8 ns) is large enough to merge four of the photons
+    # into two pulses
+    assert pulses[2] == (input_features.shape[0] - 2)
+
+
 def test_graph_definition() -> None:
     """Tests the forward pass of GraphDefinition."""
     # Test configuration