From 810f6c7299f2c636239b157345d49c1e1491edca Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 14:32:16 +0200
Subject: [PATCH 01/13] add merging functionality to graph_definition

---
 .../models/graphs/graph_definition.py         | 194 +++++++++++++++++-
 1 file changed, 186 insertions(+), 8 deletions(-)

diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
index e384425f9..6a6ece3ee 100644
--- a/src/graphnet/models/graphs/graph_definition.py
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -34,6 +34,8 @@ def __init__(
         sensor_mask: Optional[List[int]] = None,
         string_mask: Optional[List[int]] = None,
         sort_by: str = None,
+        merge_coincident: bool = False,
+        merge_window: Optional[float] = None,
     ):
         """Construct ´GraphDefinition´. The ´detector´ holds.
 
@@ -62,9 +64,16 @@ def __init__(
             add_inactive_sensors: If True, inactive sensors will be appended
                 to the graph with padded pulse information. Defaults to False.
             sensor_mask: A list of sensor id's to be masked from the graph. Any
-                sensor listed here will be removed from the graph. Defaults to None.
-            string_mask: A list of string id's to be masked from the graph. Defaults to None.
+                sensor listed here will be removed from the graph.
+                Defaults to None.
+            string_mask: A list of string id's to be masked from the graph.
+                Defaults to None.
             sort_by: Name of node feature to sort by. Defaults to None.
+            merge_coincident: If True, raw pulses/photons arriving on the same
+                PMT within `merge_window` ns will be merged into a single pulse.
+            merge_window: The size of the time window (in ns) used to merge
+                coincident pulses/photons. Has no effect if `merge_coincident`
+                is `False`.
         """
         # Base class constructor
         super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -81,6 +90,13 @@ def __init__(
         self._string_mask = string_mask
         self._add_inactive_sensors = add_inactive_sensors
 
+        self._time_column = "t"
+
+        self._n_modules = self._detector.geometry_table.shape[0]
+        self._merge_window = merge_window  # = 4.5
+        self._merge = merge_coincident
+        self._charge_key = "charge"
+
         self._resolve_masks()
 
         if self._edge_definition is None:
@@ -138,6 +154,18 @@ def __init__(
         else:
             self.rng = default_rng()
 
+        if merge_coincident:
+            if merge_window is None:
+                raise AssertionError(
+                    f"Got ´merge_coincident´={merge_coincident}, "
+                    "but `merge_window` = `None`."
+                    " Please specify a value."
+                )
+            elif merge_window <= 0:
+                raise AssertionError(
+                    f"`merge_window` must be > 0. Got {merge_window}"
+                )
+
     def forward(  # type: ignore
         self,
         input_features: np.ndarray,
@@ -152,10 +180,11 @@ def forward(  # type: ignore
         """Construct graph as ´Data´ object.
 
         Args:
-            input_features: Input features for graph construction. Shape ´[num_rows, d]´
+            input_features: Input features for graph construction.
+                Shape ´[num_rows, d]´
            input_feature_names: name of each column. Shape ´[,d]´.
            truth_dicts: Dictionary containing truth labels.
-            custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.
+            custom_label_functions: Custom label functions.
            loss_weight_column: Name of column that holds loss weight.
                Defaults to None.
            loss_weight: Loss weight associated with event. Defaults to None.
@@ -188,6 +217,12 @@ def forward(  # type: ignore
         # Gaussian perturbation of each column if perturbation dict is given
         input_features = self._perturb_input(input_features)
 
+        # Merge coincident pulses
+        if self._merge:
+            input_features = self._merge_into_pulses(
+                input_features=input_features
+            )
+
         # Transform to pytorch tensor
         input_features = torch.tensor(input_features, dtype=self.dtype)
 
@@ -244,9 +279,10 @@ def _resolve_masks(self) -> None:
         """Handle cases with sensor/string masks."""
         if self._sensor_mask is not None:
             if self._string_mask is not None:
-                assert (
-                    1 == 2
-                ), """Got arguments for both `sensor_mask`and `string_mask`. Please specify only one. """
+                raise AssertionError(
+                    "Got arguments for both `sensor_mask` and "
+                    "`string_mask`. Please specify only one."
+                )
 
         if (self._sensor_mask is None) & (self._string_mask is not None):
             self._sensor_mask = self._convert_string_to_sensor_mask()
@@ -314,8 +350,11 @@ def _geometry_table_lookup(
         return self._detector.geometry_table.loc[idx, :].index
 
     def _validate_input(
-        self, input_features: np.array, input_feature_names: List[str]
+        self,
+        input_features: np.array,
+        input_feature_names: List[str],
     ) -> None:
+
         # node feature matrix dimension check
         assert input_features.shape[1] == len(input_feature_names)
@@ -450,3 +489,142 @@ def _add_custom_labels(
         for key, fn in custom_label_functions.items():
             graph[key] = fn(graph)
         return graph
+
+    def _merge_into_pulses(
+        self, input_features: np.ndarray
+    ) -> Dict[str, List]:
+        """Merge photon attributes into pulses and add pseudo-charge."""
+        photons = {}
+        for key in self._input_feature_names:
+            photons[key] = input_features[
+                :, self._input_feature_names.index(key)
+            ].tolist()
+
+        # Create temporary module ids based on xyz coordinates
+        ids = self._assign_temp_ids(
+            x=photons["sensor_pos_x"],
+            y=photons["sensor_pos_y"],
+            z=photons["sensor_pos_z"],
+        )
+
+        # Identify photons that need to be merged
+        assert isinstance(self._merge_window, float)
+        idx = self._find_photons_for_merging(
+            t=photons["t"], ids=ids, merge_window=self._merge_window
+        )
+
+        # Merge photon attributes based on temporary ids
+        pulses = self._merge_to_pulses(data_dict=photons, ids_to_merge=idx)
+
+        # Delete photons that were merged
+        delete_these = []
+        for group in idx:
+            delete_these.extend(group)
+
+        if len(delete_these) > 0:
+            for key in photons.keys():
+                photons[key] = np.delete(
+                    np.array(photons[key]), delete_these
+                ).tolist()
+
+        # Add the pulses instead
+        for key in photons.keys():
+            photons[key].extend(pulses[key])
+        del pulses  # save memory
+
+        input_features = np.concatenate(
+            [
+                np.array(photons[key]).reshape(-1, 1)
+                for key in self._input_feature_names
+            ],
+            axis=1,
+        )
+
+        return input_features
+
+    def _merge_to_pulses(
+        self, data_dict: Dict[str, List], ids_to_merge: List[List[int]]
+    ) -> Dict[str, List]:
+        """Merge photon attributes into pulses according to assigned ids."""
+        # Initialize a new dictionary to store the merged results
+        merged_dict: Dict[str, List] = {key: [] for key in data_dict.keys()}
+
+        # Iterate over the groups of IDs to merge
+        for group in ids_to_merge:
+            for key in data_dict.keys():
+                # Extract the values corresponding to the current group of IDs
+                values_to_merge = np.array([data_dict[key][i] for i in group])
+                charges = np.array(
+                    [data_dict[self._charge_key][i] for i in group]
+                )
+                weights = charges / sum(charges)
+                # Handle numeric and non-numeric fields differently
+                if all(
+                    isinstance(value, (int, float))
+                    for value in values_to_merge
+                ):
+                    # Calculate the mean for all attributes except charge
+                    if key != self._charge_key:
+                        merged_value = sum(values_to_merge * weights)
+                    else:
+                        merged_value = sum(charges)
+                else:
+                    raise TypeError(f"Cannot merge non-numeric field '{key}'.")
+                merged_dict[key].append(merged_value)
+
+        return merged_dict
+
+    def _assign_temp_ids(
+        self, x: List[float], y: List[float], z: List[float]
+    ) -> List[int]:
+        """Create a temporary module id based on xyz positions."""
+        # Convert lists to a structured NumPy array
+        data = np.array(
+            list(zip(x, y, z)),
+            dtype=[("x", float), ("y", float), ("z", float)],
+        )
+
+        # Get the unique rows and the indices to reconstruct
+        # the original array with IDs
+        _, ids = np.unique(data, return_inverse=True, axis=0)
+
+        return ids.tolist()
+
+    def _find_photons_for_merging(
+        self, t: List[float], ids: List[int], merge_window: float
+    ) -> List[List[int]]:
+        """Identify photons that need to be merged."""
+        # Convert lists to a structured NumPy array
+        data = np.array(
+            list(zip(t, ids)), dtype=[("time", float), ("id", int)]
+        )
+
+        # Get original indices after sorting by ID first and then by time
+        sorted_indices = np.argsort(data, order=["id", "time"])
+        sorted_data = data[sorted_indices]
+
+        close_elements_indices = []
+        current_group = [sorted_indices[0]]
+
+        for i in range(1, len(sorted_data)):
+            current_value = sorted_data[i]["time"]
+            current_id_value = sorted_data[i]["id"]
+
+            # Compare with the last element in the current group
+            if (
+                current_id_value == sorted_data[i - 1]["id"]
+                and current_value - sorted_data[i - 1]["time"] < merge_window
+            ):
+                current_group.append(sorted_indices[i])
+            else:
+                # If the group has more than one element, add it to the results
+                if len(current_group) > 1:
+                    close_elements_indices.append(current_group)
+                # Start a new group
+                current_group = [sorted_indices[i]]
+
+        # Append the last group if it has more than one element
+        if len(current_group) > 1:
+            close_elements_indices.append(current_group)
+
+        return close_elements_indices
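Note on the grouping rule in `_find_photons_for_merging` above: hits on one PMT are time-sorted and compared pairwise, so a chain of hits that are each within `merge_window` of the previous one collapses into a single pulse even if the first and last hit are further apart than the window itself. A minimal sketch on made-up numbers (not part of the patch):

    import numpy as np

    times = np.array([0.1, 2.0, 3.8])  # three photons on the same PMT
    window = 2.0

    gaps = np.diff(times)              # [1.9, 1.8] -- both below the window
    print(all(gaps < window))          # True: all three merge into ONE pulse,
                                       # although 3.8 - 0.1 = 3.7 > window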
From 76c8b832e4d2e0671515263c730b03311a306abd Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 14:41:23 +0200
Subject: [PATCH 02/13] generalize temp ids to xyz

---
 src/graphnet/models/graphs/graph_definition.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
index 6a6ece3ee..3a986d7d2 100644
--- a/src/graphnet/models/graphs/graph_definition.py
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -501,10 +501,11 @@ def _merge_into_pulses(
         ].tolist()
 
         # Create temporary module ids based on xyz coordinates
+        xyz = self._detector.xyz
         ids = self._assign_temp_ids(
-            x=photons["sensor_pos_x"],
-            y=photons["sensor_pos_y"],
-            z=photons["sensor_pos_z"],
+            x=photons[xyz[0]],
+            y=photons[xyz[1]],
+            z=photons[xyz[2]],
         )
 
         # Identify photons that need to be merged

From b9cf465dc57f10074b155e677299bc9fa1df0d9a Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 14:47:42 +0200
Subject: [PATCH 03/13] reference time column in Detector

---
 src/graphnet/models/graphs/graph_definition.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
index 3a986d7d2..2b51bedcc 100644
--- a/src/graphnet/models/graphs/graph_definition.py
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -511,7 +511,9 @@ def _merge_into_pulses(
         # Identify photons that need to be merged
         assert isinstance(self._merge_window, float)
         idx = self._find_photons_for_merging(
-            t=photons["t"], ids=ids, merge_window=self._merge_window
+            t=photons[self._detector.time_column],
+            ids=ids,
+            merge_window=self._merge_window,
         )
 
         # Merge photon attributes based on temporary ids
From 7d487f4ab80fc19f996920dbb22be8ac6a205761 Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 15:01:40 +0200
Subject: [PATCH 04/13] add `sensor_time_name` as `Detector` property

---
 src/graphnet/models/detector/detector.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py
index 9b9fc61b0..e28bc61d3 100644
--- a/src/graphnet/models/detector/detector.py
+++ b/src/graphnet/models/detector/detector.py
@@ -33,7 +33,7 @@ def forward(  # type: ignore
     @property
     def geometry_table(self) -> pd.DataFrame:
         """Public get method for retrieving a `Detector`s geometry table."""
-        if ~hasattr(self, "_geometry_table"):
+        if not hasattr(self, "_geometry_table"):
             try:
                 assert hasattr(self, "geometry_table_path")
             except AssertionError as e:
@@ -60,6 +60,11 @@ def sensor_index_name(self) -> str:
         """Public get method for retrieving the sensor id column name."""
         return self.sensor_id_column
 
+    @property
+    def sensor_time_name(self) -> str:
+        """Public get method for retrieving the sensor time column name."""
+        return self.sensor_time_column
+
     @final
     def _standardize(
         self, input_features: torch.tensor, input_feature_names: List[str]
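The first hunk above also fixes a subtle bug worth spelling out: `~` is Python's bitwise NOT, not boolean negation, so `~hasattr(...)` evaluates to -2 or -1 — both truthy — and the load-and-cache branch ran on every access. A quick demonstration in the interpreter:

    >>> ~True, ~False            # bitwise NOT on bools yields ints
    (-2, -1)
    >>> bool(~True), bool(~False)
    (True, True)
    >>> not True, not False      # boolean negation, as intended
    (False, True)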
From 6779ee0fa3a6f6c8596cc1bb3327e7972e5863f5 Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 15:06:36 +0200
Subject: [PATCH 05/13] add `sensor_time_column` to all Detectors

---
 src/graphnet/models/detector/icecube.py    |  3 +++
 src/graphnet/models/detector/liquido.py    |  1 +
 src/graphnet/models/detector/prometheus.py | 12 ++++++++++++
 3 files changed, 16 insertions(+)

diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index 877a0bf65..1c9493785 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -17,6 +17,7 @@ class IceCube86(Detector):
     xyz = ["dom_x", "dom_y", "dom_z"]
     string_id_column = "string"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "dom_time"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension of input data."""
@@ -57,6 +58,7 @@ class IceCubeKaggle(Detector):
     xyz = ["x", "y", "z"]
     string_id_column = "string"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "time"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension of input data."""
@@ -122,6 +124,7 @@ class IceCubeUpgrade(Detector):
     xyz = ["dom_x", "dom_y", "dom_z"]
     string_id_column = "string"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "dom_time"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension of input data."""
diff --git a/src/graphnet/models/detector/liquido.py b/src/graphnet/models/detector/liquido.py
index a993b2344..4d1b79aa2 100644
--- a/src/graphnet/models/detector/liquido.py
+++ b/src/graphnet/models/detector/liquido.py
@@ -17,6 +17,7 @@ class LiquidO_v1(Detector):
     xyz = ["sipm_x", "sipm_y", "sipm_z"]
     string_id_column = "fiber_id"
     sensor_id_column = "sipm_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py
index 90a0fbd29..5c77fbbce 100644
--- a/src/graphnet/models/detector/prometheus.py
+++ b/src/graphnet/models/detector/prometheus.py
@@ -17,6 +17,7 @@ class ORCA150SuperDense(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -47,6 +48,7 @@ class TRIDENT1211(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -77,6 +79,7 @@ class IceCubeUpgrade7(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -107,6 +110,7 @@ class WaterDemo81(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -137,6 +141,7 @@ class BaikalGVD8(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -167,6 +172,7 @@ class IceDemo81(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -197,6 +203,7 @@ class ARCA115(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -227,6 +234,7 @@ class ORCA150(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -257,6 +265,7 @@ class IceCube86Prometheus(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -287,6 +296,7 @@ class IceCubeDeepCore8(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -317,6 +327,7 @@ class IceCubeGen2(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
@@ -344,6 +355,7 @@ class PONETriangle(Detector):
     xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
     string_id_column = "sensor_string_id"
     sensor_id_column = "sensor_id"
+    sensor_time_column = "t"
 
     def feature_map(self) -> Dict[str, Callable]:
         """Map standardization functions to each dimension."""
From e9e3a6889b871858c455f3efd9eff813b6702260 Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 15:14:45 +0200
Subject: [PATCH 06/13] pass new args through specific graph implementations

---
 src/graphnet/models/graphs/graphs.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py
index d73b2c961..c0f0cf2db 100644
--- a/src/graphnet/models/graphs/graphs.py
+++ b/src/graphnet/models/graphs/graphs.py
@@ -1,6 +1,6 @@
 """A module containing different graph representations in GraphNeT."""
 
-from typing import List, Optional, Dict, Union
+from typing import List, Optional, Dict, Union, Any
 
 import torch
 from numpy.random import Generator
@@ -23,6 +23,7 @@ def __init__(
         seed: Optional[Union[int, Generator]] = None,
         nb_nearest_neighbours: int = 8,
         columns: List[int] = [0, 1, 2],
+        **kwargs: Any,
     ) -> None:
         """Construct k-nn graph representation.
 
@@ -53,6 +54,7 @@ def __init__(
             input_feature_names=input_feature_names,
             perturbation_dict=perturbation_dict,
             seed=seed,
+            **kwargs,
         )
 
 
@@ -70,6 +72,7 @@ def __init__(
         dtype: Optional[torch.dtype] = torch.float,
         perturbation_dict: Optional[Dict[str, float]] = None,
         seed: Optional[Union[int, Generator]] = None,
+        **kwargs: Any,
    ) -> None:
         """Construct isolated nodes graph representation.
 
@@ -94,4 +97,5 @@ def __init__(
             input_feature_names=input_feature_names,
             perturbation_dict=perturbation_dict,
             seed=seed,
+            **kwargs,
         )
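With this patch the merging options from PATCH 01 can be passed straight through the concrete graph representations. A minimal usage sketch (assuming the usual graphnet import paths; the window value is illustrative, not a recommendation):

    from graphnet.models.detector.prometheus import ORCA150SuperDense
    from graphnet.models.graphs import KNNGraph

    graph_definition = KNNGraph(
        detector=ORCA150SuperDense(),
        merge_coincident=True,  # forwarded to GraphDefinition via **kwargs
        merge_window=4.5,       # ns; must be a positive number
    )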
From 6f993ce6de9b4d5e878524a78f9486794223dcb1 Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 15:35:32 +0200
Subject: [PATCH 07/13] add `charge_name` as Detector property

---
 src/graphnet/models/detector/detector.py       |  5 +++++
 src/graphnet/models/graphs/graph_definition.py | 17 +++++++----------
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py
index e28bc61d3..2802f48b3 100644
--- a/src/graphnet/models/detector/detector.py
+++ b/src/graphnet/models/detector/detector.py
@@ -65,6 +65,11 @@ def sensor_time_name(self) -> str:
         """Public get method for retrieving the sensor time column name."""
         return self.sensor_time_column
 
+    @property
+    def charge_name(self) -> str:
+        """Public get method for retrieving the charge column name."""
+        return self.charge_column
+
     @final
     def _standardize(
         self, input_features: torch.tensor, input_feature_names: List[str]
diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
index 2b51bedcc..29716eede 100644
--- a/src/graphnet/models/graphs/graph_definition.py
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -89,13 +89,9 @@ def __init__(
         self._sensor_mask = sensor_mask
         self._string_mask = string_mask
         self._add_inactive_sensors = add_inactive_sensors
-
-        self._time_column = "t"
-
-        self._n_modules = self._detector.geometry_table.shape[0]
-        self._merge_window = merge_window  # = 4.5
+        self._n_modules = self._detector.geometry_table.shape[0]
+        self._merge_window = merge_window
         self._merge = merge_coincident
-        self._charge_key = "charge"
 
         self._resolve_masks()
 
@@ -502,16 +498,17 @@ def _merge_into_pulses(
 
         # Create temporary module ids based on xyz coordinates
         xyz = self._detector.xyz
+        print(xyz)
         ids = self._assign_temp_ids(
             x=photons[xyz[0]],
             y=photons[xyz[1]],
             z=photons[xyz[2]],
         )
 
         # Identify photons that need to be merged
-        assert isinstance(self._merge_window, float)
+        assert isinstance(self._merge_window, (float, int))
         idx = self._find_photons_for_merging(
-            t=photons[self._detector.time_column],
+            t=photons[self._detector.sensor_time_name],
             ids=ids,
             merge_window=self._merge_window,
         )
each dimension.""" @@ -142,6 +146,7 @@ class BaikalGVD8(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" @@ -173,6 +178,7 @@ class IceDemo81(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" @@ -204,6 +210,7 @@ class ARCA115(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" @@ -235,6 +242,7 @@ class ORCA150(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" @@ -266,6 +274,7 @@ class IceCube86Prometheus(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" @@ -297,6 +306,7 @@ class IceCubeDeepCore8(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" @@ -328,6 +338,7 @@ class IceCubeGen2(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" @@ -356,6 +367,7 @@ class PONETriangle(Detector): string_id_column = "sensor_string_id" sensor_id_column = "sensor_id" sensor_time_column = "t" + charge_column = "charge" def feature_map(self) -> Dict[str, Callable]: """Map standardization functions to each dimension.""" From 72de10e4d6487ef3989069ab8fd8d402c5fb0811 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Wed, 11 Sep 2024 15:54:47 +0200 Subject: [PATCH 09/13] add member variable for charge in graph def --- .../models/graphs/graph_definition.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 29716eede..7e493109e 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -92,6 +92,7 @@ def __init__( self._n_modules = self._detector.geometry_table.shape[0] self._merge_window = merge_window self._merge = merge_coincident + self._charge_key = self._detector.charge_name self._resolve_masks() @@ -223,7 +224,7 @@ def forward( # type: ignore input_features = torch.tensor(input_features, dtype=self.dtype) # Standardize / Scale node features - input_features = self._detector(input_features, input_feature_names) + # input_features = self._detector(input_features, input_feature_names) # Create graph & get new node feature names graph, node_feature_names = self._node_definition(input_features) @@ -554,10 +555,18 @@ def _merge_to_pulses( for key in data_dict.keys(): # Extract the values corresponding to the current group of IDs values_to_merge = 
From 72de10e4d6487ef3989069ab8fd8d402c5fb0811 Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 15:54:47 +0200
Subject: [PATCH 09/13] add member variable for charge in graph def

---
 .../models/graphs/graph_definition.py         | 19 ++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py
index 29716eede..7e493109e 100644
--- a/src/graphnet/models/graphs/graph_definition.py
+++ b/src/graphnet/models/graphs/graph_definition.py
@@ -92,6 +92,7 @@ def __init__(
         self._n_modules = self._detector.geometry_table.shape[0]
         self._merge_window = merge_window
         self._merge = merge_coincident
+        self._charge_key = self._detector.charge_name
 
         self._resolve_masks()
 
@@ -223,7 +224,7 @@ def forward(  # type: ignore
         input_features = torch.tensor(input_features, dtype=self.dtype)
 
         # Standardize / Scale node features
-        input_features = self._detector(input_features, input_feature_names)
+        # input_features = self._detector(input_features, input_feature_names)
 
         # Create graph & get new node feature names
         graph, node_feature_names = self._node_definition(input_features)
@@ -554,10 +555,18 @@ def _merge_to_pulses(
             for key in data_dict.keys():
                 # Extract the values corresponding to the current group of IDs
                 values_to_merge = np.array([data_dict[key][i] for i in group])
-                charges = np.array(
-                    [data_dict[self._charge_key][i] for i in group]
-                )
-                weights = charges / sum(charges)
+                if self._charge_key in data_dict.keys():
+                    charges = np.array(
+                        [data_dict[self._charge_key][i] for i in group]
+                    )
+                    weights = charges / sum(charges)
+                else:
+                    self.warning_once(
+                        f"`{self._charge_key}` not available in"
+                        f" {self._input_feature_names}."
+                        " Cannot do weighted sum."
+                    )
+                    weights = np.ones(len(group)) / len(group)  # unweighted mean
                 # Handle numeric and non-numeric fields differently
                 if all(
                     isinstance(value, (int, float))
                     for value in values_to_merge
                 ):
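The effect of the new fallback, on hypothetical numbers: with a charge column present, merged attributes are charge-weighted means (and the merged charge is the plain sum); without one, every photon in a group counts equally:

    import numpy as np

    values = np.array([10.0, 12.0, 14.0])   # e.g. arrival times in a group

    charges = np.array([1.0, 1.0, 2.0])     # charge column available
    weights = charges / charges.sum()
    print((values * weights).sum())         # 12.5 -> charge-weighted mean

    weights = np.ones(len(values)) / len(values)   # no charge column
    print((values * weights).sum())         # 12.0 -> unweighted mean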
From 8227d741d5a46758ea9b630345af5b865d36cfb4 Mon Sep 17 00:00:00 2001
From: RasmusOrsoe
Date: Wed, 11 Sep 2024 15:57:58 +0200
Subject: [PATCH 10/13] add unit test for merging functionality

---
 tests/models/test_graph_definition.py | 36 +++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/tests/models/test_graph_definition.py b/tests/models/test_graph_definition.py
index 091cc0bac..00a17ecc5 100644
--- a/tests/models/test_graph_definition.py
+++ b/tests/models/test_graph_definition.py
@@ -17,6 +17,42 @@
 from typing import List
 
 
+def test_merging_coincident_pulses() -> None:
+    """Test merging of coincident pulses/photons."""
+    input_feature_names = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z", "t"]
+    input_features = np.array(
+        [
+            [1, 2, 3, 0.1],
+            [1, 2, 3, 2],
+            [2, 2, 1, 0.1],
+            [1, 2, 1, 0.1],
+            [1, 2, 1, 7],
+        ]
+    )
+
+    pulses = []
+    for window in [0.5, 2, 8]:
+        graph_definition = KNNGraph(
+            detector=ORCA150SuperDense(),
+            merge_coincident=True,
+            merge_window=window,
+            input_feature_names=input_feature_names,
+        )
+        y = graph_definition(
+            input_features=input_features,
+            input_feature_names=input_feature_names,
+        )
+        pulses.append(y.x.shape[0])
+
+    # Merging window (0.5 ns) is too small to merge anything,
+    # so nothing should happen!
+    assert pulses[0] == input_features.shape[0]
+    # Merging window (2 ns) is large enough to merge two of the pulses
+    assert pulses[1] == (input_features.shape[0] - 1)
+    # Merging window (8 ns) is large enough to merge two pairs of pulses
+    assert pulses[2] == (input_features.shape[0] - 2)
+
+
 def test_graph_definition() -> None:
     """Tests the forward pass of GraphDefinition."""
     # Test configuration
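A quick cross-check of the pulse counts asserted above (not part of the patch): the five input pulses occupy three distinct positions, with in-PMT time gaps of 1.9 ns at (1, 2, 3) and 6.9 ns at (1, 2, 1):

    gaps = [2 - 0.1, 7 - 0.1]        # 1.9 ns and 6.9 ns
    for window in (0.5, 2, 8):
        merged = sum(gap < window for gap in gaps)
        print(window, 5 - merged)    # -> 5, 4, 3 surviving pulses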
Merging skipped.") + return files_to_merge = files else: # Raise error From 2f778a911ea790a35d3c2303a98aa1eb797bbcd0 Mon Sep 17 00:00:00 2001 From: RasmusOrsoe Date: Fri, 13 Sep 2024 16:14:53 +0200 Subject: [PATCH 13/13] remove unintended comment --- src/graphnet/models/graphs/graph_definition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 6a7b6a777..9c6c266c7 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -224,7 +224,7 @@ def forward( # type: ignore input_features = torch.tensor(input_features, dtype=self.dtype) # Standardize / Scale node features - # input_features = self._detector(input_features, input_feature_names) + input_features = self._detector(input_features, input_feature_names) # Create graph & get new node feature names graph, node_feature_names = self._node_definition(input_features)