diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 75be67797..7ae7881a6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -36,8 +36,16 @@ jobs: name: Unit tests - IceTray needs: [ check-codeclimate-credentials ] runs-on: ubuntu-latest - container: icecube/icetray:icetray-prod-v1.8.1-ubuntu20.04-X64 + container: + image: icecube/icetray:icetray-prod-v1.8.1-ubuntu20.04-X64 + options: --user root steps: + - name: install git + run: | + apt-get --yes install sudo + sudo apt update --fix-missing --yes + sudo apt upgrade --yes + sudo apt-get install --yes git-all - name: Set environment variables run: | echo "PATH=/usr/local/icetray/bin:$PATH" >> $GITHUB_ENV @@ -57,13 +65,13 @@ jobs: editable: true - name: Run unit tests and generate coverage report run: | - coverage run --source=graphnet -m pytest tests/ --ignore=tests/examples - coverage run --source=graphnet -m pytest tests/examples/01_icetray + coverage run --source=graphnet -m pytest tests/ --ignore=tests/examples/04_training + coverage run -a --source=graphnet -m pytest tests/examples/04_training coverage xml -o coverage.xml - #- name: Work around permission issue - # run: | - # git config --global --add safe.directory /__w/graphnet/graphnet + - name: Work around permission issue + run: | + git config --global --add safe.directory /__w/graphnet/graphnet - name: Publish code coverage uses: paambaati/codeclimate-action@v3.0.0 if: needs.check-codeclimate-credentials.outputs.has_credentials == 'true' @@ -93,14 +101,21 @@ jobs: editable: true - name: Print available disk space after graphnet install run: df -h + - name: Print packages in pip + run: | + pip show torch + pip show torch-geometric + pip show torch-cluster + pip show torch-sparse + pip show torch-scatter - name: Run unit tests and generate coverage report run: | set -o pipefail # To propagate exit code from pytest coverage run --source=graphnet -m pytest tests/ --ignore=tests/utilities --ignore=tests/data/ --ignore=tests/deployment/ --ignore=tests/examples/01_icetray/ - coverage run --source=graphnet -m pytest tests/utilities + coverage run -a --source=graphnet -m pytest tests/utilities coverage report -m - name: Print available disk space after unit tests - run: df -h + run: df -h build-macos: name: Unit tests - macOS @@ -116,8 +131,15 @@ jobs: with: editable: true hardware: "macos" + - name: Print packages in pip + run: | + pip show torch + pip show torch-geometric + pip show torch-cluster + pip show torch-sparse + pip show torch-scatter - name: Run unit tests and generate coverage report run: | set -o pipefail # To propagate exit code from pytest coverage run --source=graphnet -m pytest tests/ --ignore=tests/data/ --ignore=tests/deployment/ --ignore=tests/examples/ - coverage report -m + coverage report -m \ No newline at end of file diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 7d5ad7964..09cb04bf8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -19,6 +19,28 @@ jobs: packages: write contents: read steps: + - name: Before Clean-up + run: | + echo "Free space:" + df -h + + - name: Free Disk Space + uses: jlumbroso/free-disk-space@main + with: + tool-cache: true + + # all of these default to true, but feel free to set to + # false if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + swap-storage: true + + - name: After Clean-up + run: | + echo "Free space:" + df -h - name: Checkout uses: actions/checkout@v3 - name: 
Set up QEMU diff --git a/README.md b/README.md index 8afc9d773..7a2e69af9 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,6 @@ $ conda create --name graphnet python=3.8 gcc_linux-64 gxx_linux-64 libgcc cudat $ conda activate graphnet # Optional (graphnet) $ pip install -r requirements/torch_cpu.txt -e .[develop,torch] # CPU-only torch (graphnet) $ pip install -r requirements/torch_gpu.txt -e .[develop,torch] # GPU support -(graphnet) $ pip install -r requirements/torch_macos.txt -e .[develop,torch] # On macOS ``` This should allow you to e.g. run the scripts in [examples/](./examples/) out of the box. diff --git a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml index d70de5294..523f4fa90 100644 --- a/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml +++ b/configs/datasets/dev_lvl7_robustness_muon_neutrino_0000.yml @@ -10,7 +10,7 @@ graph_definition: node_definition: arguments: {} class_name: NodesAsPulses - node_feature_names: [dom_x, dom_y, dom_z, dom_time, charge, rde, pmt_area] + input_feature_names: [dom_x, dom_y, dom_z, dom_time, charge, rde, pmt_area] class_name: KNNGraph pulsemaps: - SRTTWOfflinePulsesDC diff --git a/configs/datasets/test_data_sqlite.yml b/configs/datasets/test_data_sqlite.yml index 689e8af31..11ea4496d 100644 --- a/configs/datasets/test_data_sqlite.yml +++ b/configs/datasets/test_data_sqlite.yml @@ -10,7 +10,7 @@ graph_definition: node_definition: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph index_column: event_no loss_weight_column: null diff --git a/configs/datasets/training_classification_example_data_sqlite.yml b/configs/datasets/training_classification_example_data_sqlite.yml index ae94420ee..3a13f8749 100644 --- a/configs/datasets/training_classification_example_data_sqlite.yml +++ b/configs/datasets/training_classification_example_data_sqlite.yml @@ -10,7 +10,7 @@ graph_definition: node_definition: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph pulsemaps: - total diff --git a/configs/datasets/training_example_data_parquet.yml b/configs/datasets/training_example_data_parquet.yml index d8bde7e30..67abca0c4 100644 --- a/configs/datasets/training_example_data_parquet.yml +++ b/configs/datasets/training_example_data_parquet.yml @@ -10,7 +10,7 @@ graph_definition: node_definition: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph pulsemaps: - total diff --git a/configs/datasets/training_example_data_sqlite.yml b/configs/datasets/training_example_data_sqlite.yml index b33a0ee0c..20c4aa8c0 100644 --- a/configs/datasets/training_example_data_sqlite.yml +++ b/configs/datasets/training_example_data_sqlite.yml @@ -10,7 +10,7 @@ graph_definition: node_definition: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph pulsemaps: - total diff --git a/configs/models/dynedge_PID_classification_example.yml b/configs/models/dynedge_PID_classification_example.yml index 57fec3e88..f9b1509c4 100644 --- 
a/configs/models/dynedge_PID_classification_example.yml +++ b/configs/models/dynedge_PID_classification_example.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: {eps: 0.001, lr: 0.001} diff --git a/configs/models/dynedge_position_custom_scaling_example.yml b/configs/models/dynedge_position_custom_scaling_example.yml index 195695a8d..013dab592 100644 --- a/configs/models/dynedge_position_custom_scaling_example.yml +++ b/configs/models/dynedge_position_custom_scaling_example.yml @@ -17,7 +17,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: null + input_feature_names: null class_name: KNNGraph gnn: ModelConfig: diff --git a/configs/models/dynedge_position_example.yml b/configs/models/dynedge_position_example.yml deleted file mode 100644 index c82223825..000000000 --- a/configs/models/dynedge_position_example.yml +++ /dev/null @@ -1,44 +0,0 @@ -arguments: - coarsening: null - detector: - ModelConfig: - arguments: - graph_builder: - ModelConfig: - arguments: {columns: null, nb_nearest_neighbours: 8} - class_name: KNNGraphBuilder - scalers: null - class_name: IceCubeDeepCore - gnn: - ModelConfig: - arguments: - add_global_variables_after_pooling: false - dynedge_layer_sizes: null - features_subset: null - global_pooling_schemes: [min, max, mean, sum] - nb_inputs: 7 - nb_neighbours: 8 - post_processing_layer_sizes: null - readout_layer_sizes: null - class_name: DynEdge - optimizer_class: '!class torch.optim.adam Adam' - optimizer_kwargs: {eps: 0.001, lr: 1e-05} - scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau' - scheduler_config: {frequency: 1, monitor: val_loss} - scheduler_kwargs: {patience: 5} - tasks: - - ModelConfig: - arguments: - hidden_size: 128 - loss_function: - ModelConfig: - arguments: {} - class_name: MSELoss - loss_weight: null - target_labels: ["position_x", "position_y", "position_z"] - transform_inference: null - transform_prediction_and_target: null - transform_support: null - transform_target: null - class_name: PositionReconstruction -class_name: StandardModel diff --git a/configs/models/example_direction_reconstruction_model.yml b/configs/models/example_direction_reconstruction_model.yml index cb1c4d841..faf168ed5 100644 --- a/configs/models/example_direction_reconstruction_model.yml +++ b/configs/models/example_direction_reconstruction_model.yml @@ -13,7 +13,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph gnn: ModelConfig: diff --git a/configs/models/example_energy_reconstruction_model.yml b/configs/models/example_energy_reconstruction_model.yml index 827c84748..5983ef799 100644 --- a/configs/models/example_energy_reconstruction_model.yml +++ b/configs/models/example_energy_reconstruction_model.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: {eps: 0.001, lr: 0.001} diff --git 
a/configs/models/example_vertex_position_reconstruction_model.yml b/configs/models/example_vertex_position_reconstruction_model.yml index 0522a1f2d..ce0a993c4 100644 --- a/configs/models/example_vertex_position_reconstruction_model.yml +++ b/configs/models/example_vertex_position_reconstruction_model.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] + input_feature_names: [sensor_pos_x, sensor_pos_y, sensor_pos_z, t] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: {eps: 0.001, lr: 0.001} diff --git a/examples/02_data/04_ensemble_dataset.py b/examples/02_data/04_ensemble_dataset.py index f1cc9de68..4ade95de6 100644 --- a/examples/02_data/04_ensemble_dataset.py +++ b/examples/02_data/04_ensemble_dataset.py @@ -24,7 +24,7 @@ detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=features, + input_feature_names=features, ) diff --git a/examples/05_pisa/02_make_pipeline_database.py b/examples/05_pisa/02_make_pipeline_database.py index 17e86646d..722b997f3 100644 --- a/examples/05_pisa/02_make_pipeline_database.py +++ b/examples/05_pisa/02_make_pipeline_database.py @@ -65,7 +65,7 @@ def main() -> None: detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) # Remove `interaction_time` if it exists diff --git a/requirements/torch_cpu.txt b/requirements/torch_cpu.txt index 6f68e3600..59e273288 100644 --- a/requirements/torch_cpu.txt +++ b/requirements/torch_cpu.txt @@ -1,2 +1,2 @@ --find-links https://download.pytorch.org/whl/cpu ---find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html \ No newline at end of file +--find-links https://data.pyg.org/whl/torch-2.1.0+cpu.html \ No newline at end of file diff --git a/requirements/torch_gpu.txt b/requirements/torch_gpu.txt index 553d306e5..1f1abba3f 100644 --- a/requirements/torch_gpu.txt +++ b/requirements/torch_gpu.txt @@ -1,4 +1,4 @@ -# Contains packages recommended for functional performance +# Contains packages requirements for GPU installation --find-links https://download.pytorch.org/whl/torch_stable.html -torch==2.0.1+cu117 ---find-links https://data.pyg.org/whl/torch-2.0.0+cu117.html +torch==2.1.0+cu118 +--find-links https://data.pyg.org/whl/torch-2.1.0+cu118.html diff --git a/requirements/torch_macos.txt b/requirements/torch_macos.txt index be7a35257..3e9d75df4 100644 --- a/requirements/torch_macos.txt +++ b/requirements/torch_macos.txt @@ -1,2 +1,2 @@ --find-links https://download.pytorch.org/whl/torch_stable.html ---find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html \ No newline at end of file +--find-links https://data.pyg.org/whl/torch-2.1.0+cpu.html \ No newline at end of file diff --git a/setup.py b/setup.py index ef1b48e52..3b70233ab 100644 --- a/setup.py +++ b/setup.py @@ -47,11 +47,11 @@ "versioneer", ], "torch": [ - "torch>=2.0", + "torch>=2.1", "torch-cluster>=1.6", "torch-scatter>=2.0", "torch-sparse>=0.6", - "torch-geometric>=2.1", + "torch-geometric>=2.3", "pytorch-lightning>=2.0", ], } diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index 1eca4f6cd..fbb1ee095 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -3,3 +3,4 @@ `graphnet.data` enables converting domain-specific data to industry-standard, intermediate file formats and reading this data. 
""" +from .filters import I3Filter, I3FilterMask diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index dc0deabd0..41cec5eec 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -39,6 +39,7 @@ from graphnet.utilities.filesys import find_i3_files from graphnet.utilities.imports import has_icecube_package from graphnet.utilities.logging import Logger +from graphnet.data.filters import I3Filter, NullSplitI3Filter if has_icecube_package(): from icecube import icetray, dataio # pyright: reportMissingImports=false @@ -107,6 +108,7 @@ def __init__( workers: int = 1, index_column: str = "event_no", icetray_verbose: int = 0, + i3_filters: List[I3Filter] = [], ): """Construct DataConverter. @@ -167,6 +169,14 @@ def __init__( self._input_file_batch_pattern = input_file_batch_pattern self._workers = workers + # I3Filters (NullSplitI3Filter is always included) + self._i3filters = [NullSplitI3Filter()] + i3_filters + + for filter in self._i3filters: + assert isinstance( + filter, I3Filter + ), f"{type(filter)} is not a subclass of I3Filter" + # Create I3Extractors self._extractors = I3ExtractorCollection(*extractors) @@ -433,6 +443,7 @@ def _extract_data(self, fileset: FileSet) -> List[OrderedDict]: except Exception as e: if "I3" in str(e): continue + # check if frame should be skipped if self._skip_frame(frame): continue @@ -555,14 +566,15 @@ def _get_output_file(self, input_file: str) -> str: return output_file def _skip_frame(self, frame: "icetray.I3Frame") -> bool: - """Check if frame should be skipped. - - Args: - frame: I3Frame to check. + """Check the user defined filters. Returns: - True if frame is a null split frame, else False. + bool: True if frame should be skipped, False otherwise. """ - if frame["I3EventHeader"].sub_event_stream == "NullSplit": - return True - return False + if self._i3filters is None: + return False # No filters defined, so we keep the frame + + for filter in self._i3filters: + if not filter(frame): + return True # keep_frame call false, skip the frame. + return False # All filter keep_frame calls true, keep the frame. 
diff --git a/src/graphnet/data/dataset/dataset.py b/src/graphnet/data/dataset/dataset.py index 4253788a8..c9355bbfc 100644 --- a/src/graphnet/data/dataset/dataset.py +++ b/src/graphnet/data/dataset/dataset.py @@ -282,7 +282,7 @@ def __init__( self._index_column = index_column self._truth_table = truth_table self._loss_weight_default_value = loss_weight_default_value - self._graph_definition = graph_definition + self._graph_definition = deepcopy(graph_definition) if node_truth is not None: assert isinstance(node_truth_table, str) @@ -629,8 +629,8 @@ def _create_graph( # Construct graph data object assert self._graph_definition is not None graph = self._graph_definition( - node_features=node_features, - node_feature_names=self._features[ + input_features=node_features, + input_feature_names=self._features[ 1: ], # first entry is index column truth_dicts=truth_dicts, diff --git a/src/graphnet/data/filters.py b/src/graphnet/data/filters.py new file mode 100644 index 000000000..ca83f4217 --- /dev/null +++ b/src/graphnet/data/filters.py @@ -0,0 +1,128 @@ +"""Filter classes for filtering I3-frames when converting I3-files.""" +from abc import abstractmethod +from graphnet.utilities.logging import Logger +from typing import List + +from graphnet.utilities.imports import has_icecube_package + +if has_icecube_package(): + from icecube import icetray + + +class I3Filter(Logger): + """A generic filter for I3-frames.""" + + @abstractmethod + def _keep_frame(self, frame: "icetray.I3Frame") -> bool: + """Return True if the frame is kept, False otherwise. + + Args: + frame: I3-frame + The I3-frame to check. + + Returns: + bool: True if the frame is kept, False otherwise. + """ + raise NotImplementedError + + def __call__(self, frame: "icetray.I3Frame") -> bool: + """Return True if the frame passes the filter, False otherwise. + + Args: + frame: I3-frame + The I3-frame to check. + + Returns: + bool: True if the frame passes the filter, False otherwise. + """ + pass_flag = self._keep_frame(frame) + try: + assert isinstance(pass_flag, bool) + except AssertionError: + raise TypeError( + f"Expected _keep_frame to return bool, got {type(pass_flag)}." + ) + return pass_flag + + +class NullSplitI3Filter(I3Filter): + """A filter that skips all null-split frames.""" + + def _keep_frame(self, frame: "icetray.I3Frame") -> bool: + """Check that frame is not a null-split frame. + + Returns False if the frame is a null-split frame, True otherwise. + + Args: + frame: I3-frame + The I3-frame to check. + """ + if frame.Has("I3EventHeader"): + if frame["I3EventHeader"].sub_event_stream == "NullSplit": + return False + return True + + +class I3FilterMask(I3Filter): + """Check a list of filters from the FilterMask in I3 frames.""" + + def __init__(self, filter_names: List[str], filter_any: bool = True): + """Initialize I3FilterMask. + + Args: + filter_names: List[str] + A list of filter names to check for. + filter_any: bool + Default: True. + If True, the frame is kept if any of the listed filters has passed. + If False, the frame is kept only if all of the listed filters have passed. + """ + self._filter_names = filter_names + self._filter_any = filter_any + + def _keep_frame(self, frame: "icetray.I3Frame") -> bool: + """Check if current frame should be kept. + + Args: + frame: I3-frame + The I3-frame to check.
+ """ + if "FilterMask" in frame: + if ( + self._filter_any is True + ): # Require any of the filters to pass to keep the frame + bool_list = [] + for filter_name in self._filter_names: + if filter_name not in frame["FilterMask"]: + self.warning_once( + f"FilterMask {filter_name} not found in frame. skipping filter." + ) + continue + elif frame["FilterMask"][filter].condition_passed is True: + bool_list.append(True) + else: + bool_list.append(False) + if len(bool_list) == 0: + self.warning_once( + "None of the FilterMask filters found in frame, FilterMask filters will not be applied." + ) + return any(bool_list) or len(bool_list) == 0 + else: # Require all filters to pass in order to keep the frame. + for filter_name in self._filter_names: + if filter_name not in frame["FilterMask"]: + self.warning_once( + f"FilterMask {filter_name} not found in frame, skipping filter." + ) + continue + elif frame["FilterMask"][filter].condition_passed is True: + continue # current filter passed, continue to next filter + else: + return ( + False # current filter failed so frame is skipped. + ) + return True + else: + self.warning_once( + "FilterMask not found in frame, FilterMask filters will not be applied." + ) + return True diff --git a/src/graphnet/deployment/i3modules/graphnet_module.py b/src/graphnet/deployment/i3modules/graphnet_module.py index 2c85600a3..dee0973b8 100644 --- a/src/graphnet/deployment/i3modules/graphnet_module.py +++ b/src/graphnet/deployment/i3modules/graphnet_module.py @@ -84,12 +84,12 @@ def _make_graph( ) -> Data: # py-l-i-n-t-:- -d-i-s-able=invalid-name """Process Physics I3Frame into graph.""" # Extract features - node_features = self._extract_feature_array_from_frame(frame) + input_features = self._extract_feature_array_from_frame(frame) # Prepare graph data - if len(node_features) > 0: + if len(input_features) > 0: data = self._graph_definition( - node_features=node_features, - node_feature_names=self._features, + input_features=input_features, + input_feature_names=self._features, ) return Batch.from_data_list([data]) else: diff --git a/src/graphnet/models/graphs/graph_definition.py b/src/graphnet/models/graphs/graph_definition.py index 9c4db4d47..f75f65b98 100644 --- a/src/graphnet/models/graphs/graph_definition.py +++ b/src/graphnet/models/graphs/graph_definition.py @@ -26,7 +26,7 @@ def __init__( detector: Detector, node_definition: NodeDefinition = NodesAsPulses(), edge_definition: Optional[EdgeDefinition] = None, - node_feature_names: Optional[List[str]] = None, + input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, perturbation_dict: Optional[Dict[str, float]] = None, seed: Optional[Union[int, Generator]] = None, @@ -44,7 +44,10 @@ def __init__( detector: The corresponding ´Detector´ representing the data. node_definition: Definition of nodes. Defaults to NodesAsPulses. edge_definition: Definition of edges. Defaults to None. - node_feature_names: Names of node feature columns. Defaults to None + input_feature_names: Names of each column in expected input data + that will be built into a graph. If not provided, + it is automatically assumed that all features in `Detector` is + used. dtype: data type used for node features. e.g. 
´torch.float´ perturbation_dict: Dictionary mapping a feature name to a standard deviation according to which the values for this @@ -62,25 +65,30 @@ def __init__( self._node_definition = node_definition self._perturbation_dict = perturbation_dict - if node_feature_names is None: + if input_feature_names is None: # Assume all features in Detector is used. - node_feature_names = list(self._detector.feature_map().keys()) # type: ignore - self._node_feature_names = node_feature_names + input_feature_names = list(self._detector.feature_map().keys()) # type: ignore + self._input_feature_names = input_feature_names + + # Set input data column names for node definition + self._node_definition.set_output_feature_names( + self._input_feature_names + ) # Set data type self.to(dtype) # Set Input / Output dimensions self._node_definition.set_number_of_inputs( - node_feature_names=node_feature_names + input_feature_names=input_feature_names ) - self.nb_inputs = len(self._node_feature_names) + self.nb_inputs = len(self._input_feature_names) self.nb_outputs = self._node_definition.nb_outputs # Set perturbation_cols if needed if isinstance(self._perturbation_dict, dict): self._perturbation_cols = [ - self._node_feature_names.index(key) + self._input_feature_names.index(key) for key in self._perturbation_dict.keys() ] if seed is not None: @@ -97,8 +105,8 @@ def __init__( def forward( # type: ignore self, - node_features: np.ndarray, - node_feature_names: List[str], + input_features: np.ndarray, + input_feature_names: List[str], truth_dicts: Optional[List[Dict[str, Any]]] = None, custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None, loss_weight_column: Optional[str] = None, @@ -109,8 +117,8 @@ def forward( # type: ignore """Construct graph as ´Data´ object. Args: - node_features: node features for graph. Shape ´[num_nodes, d]´ - node_feature_names: name of each column. Shape ´[,d]´. + input_features: Input features for graph construction. Shape ´[num_rows, d]´ + input_feature_names: name of each column. Shape ´[,d]´. truth_dicts: Dictionary containing truth labels. custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels. loss_weight_column: Name of column that holds loss weight. @@ -126,23 +134,27 @@ def forward( # type: ignore """ # Checks self._validate_input( - node_features=node_features, node_feature_names=node_feature_names + input_features=input_features, + input_feature_names=input_feature_names, ) # Gaussian perturbation of each column if perturbation dict is given - node_features = self._perturb_input(node_features) + input_features = self._perturb_input(input_features) # Transform to pytorch tensor - node_features = torch.tensor(node_features, dtype=self.dtype) + input_features = torch.tensor(input_features, dtype=self.dtype) # Standardize / Scale node features - node_features = self._detector(node_features, node_feature_names) + input_features = self._detector(input_features, input_feature_names) + + # Create graph & get new node feature names + graph, node_feature_names = self._node_definition(input_features) - # Create graph - graph = self._node_definition(node_features) + # Enforce dtype + graph.x = graph.x.type(self.dtype) # Attach number of pulses as static attribute. 
- graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32) + graph.n_pulses = torch.tensor(len(input_features), dtype=torch.int32) # Assign edges if self._edge_definition is not None: @@ -186,26 +198,26 @@ def forward( # type: ignore return graph def _validate_input( - self, node_features: np.array, node_feature_names: List[str] + self, input_features: np.array, input_feature_names: List[str] ) -> None: # node feature matrix dimension check - assert node_features.shape[1] == len(node_feature_names) + assert input_features.shape[1] == len(input_feature_names) # check that provided features for input is the same that the ´Graph´ # was instantiated with. - assert len(node_feature_names) == len( - self._node_feature_names - ), f"""Input features ({node_feature_names}) is not what + assert len(input_feature_names) == len( + self._input_feature_names + ), f"""Input features ({input_feature_names}) is not what {self.__class__.__name__} was instatiated - with ({self._node_feature_names})""" # noqa - for idx in range(len(node_feature_names)): + with ({self._input_feature_names})""" # noqa + for idx in range(len(input_feature_names)): assert ( - node_feature_names[idx] == self._node_feature_names[idx] + input_feature_names[idx] == self._input_feature_names[idx] ), f""" Order of node features in data - are not the same as expected. Got {node_feature_names} - vs. {self._node_feature_names}""" # noqa + are not the same as expected. Got {input_feature_names} + vs. {self._input_feature_names}""" # noqa - def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: + def _perturb_input(self, input_features: np.ndarray) -> np.ndarray: if isinstance(self._perturbation_dict, dict): self.warning_once( f"""Will randomly perturb @@ -213,13 +225,13 @@ def _perturb_input(self, node_features: np.ndarray) -> np.ndarray: using stds {self._perturbation_dict.values()}""" # noqa ) perturbed_features = self.rng.normal( - loc=node_features[:, self._perturbation_cols], + loc=input_features[:, self._perturbation_cols], scale=np.array( list(self._perturbation_dict.values()), dtype=float ), ) - node_features[:, self._perturbation_cols] = perturbed_features - return node_features + input_features[:, self._perturbation_cols] = perturbed_features + return input_features def _add_loss_weights( self, diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py index 4ae53037a..bd52eaeae 100644 --- a/src/graphnet/models/graphs/graphs.py +++ b/src/graphnet/models/graphs/graphs.py @@ -17,7 +17,7 @@ def __init__( self, detector: Detector, node_definition: NodeDefinition = NodesAsPulses(), - node_feature_names: Optional[List[str]] = None, + input_feature_names: Optional[List[str]] = None, dtype: Optional[torch.dtype] = torch.float, perturbation_dict: Optional[Dict[str, float]] = None, seed: Optional[Union[int, Generator]] = None, @@ -29,7 +29,7 @@ def __init__( Args: detector: Detector that represents your data. node_definition: Definition of nodes in the graph. - node_feature_names: Name of node features. + input_feature_names: Name of input feature columns. dtype: data type for node features. 
perturbation_dict: Dictionary mapping a feature name to a standard deviation according to which the values for this @@ -50,7 +50,7 @@ def __init__( columns=columns, ), dtype=dtype, - node_feature_names=node_feature_names, + input_feature_names=input_feature_names, perturbation_dict=perturbation_dict, seed=seed, ) diff --git a/src/graphnet/models/graphs/nodes/__init__.py b/src/graphnet/models/graphs/nodes/__init__.py index 05194b61a..0119d2b98 100644 --- a/src/graphnet/models/graphs/nodes/__init__.py +++ b/src/graphnet/models/graphs/nodes/__init__.py @@ -5,4 +5,4 @@ and their features. """ -from .nodes import NodeDefinition, NodesAsPulses +from .nodes import NodeDefinition, NodesAsPulses, PercentileClusters diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index ce539ee80..fa0400b97 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -1,6 +1,6 @@ """Class(es) for building/connecting graphs.""" -from typing import List +from typing import List, Tuple, Optional from abc import abstractmethod import torch @@ -8,29 +8,53 @@ from graphnet.utilities.decorators import final from graphnet.models import Model +from graphnet.models.graphs.utils import ( + cluster_summarize_with_percentiles, + identify_indices, +) +from copy import deepcopy class NodeDefinition(Model): # pylint: disable=too-few-public-methods """Base class for graph building.""" - def __init__(self) -> None: + def __init__( + self, input_feature_names: Optional[List[str]] = None + ) -> None: """Construct `Detector`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + if input_feature_names is not None: + self.set_output_feature_names( + input_feature_names=input_feature_names + ) @final - def forward(self, x: torch.tensor) -> Data: + def forward(self, x: torch.tensor) -> Tuple[Data, List[str]]: """Construct nodes from raw node features. Args: x: standardized node features with shape ´[num_pulses, d]´, where ´d´ is the number of node features. + node_feature_names: list of names for each column in ´x´. Returns: graph: a graph without edges + new_features_name: List of new feature names. """ - graph = self._construct_nodes(x) - return graph + graph = self._construct_nodes(x=x) + try: + self._output_feature_names + except AttributeError as e: + self.error( + f"""{self.__class__.__name__} was instantiated without + `input_feature_names` and it was not set prior to this + forward call. If you are using this class outside a + `GraphDefinition`, please instatiate + with `input_feature_names`.""" + ) # noqa + raise e + return graph, self._output_feature_names @property def nb_outputs(self) -> int: @@ -38,33 +62,152 @@ def nb_outputs(self) -> int: This the default, but may be overridden by specific inheriting classes. """ - return self.nb_inputs + return len(self._output_feature_names) @final - def set_number_of_inputs(self, node_feature_names: List[str]) -> None: + def set_number_of_inputs(self, input_feature_names: List[str]) -> None: """Return number of inputs expected by node definition. Args: - node_feature_names: name of each node feature column. + input_feature_names: name of each input feature column. + """ + assert isinstance(input_feature_names, list) + self.nb_inputs = len(input_feature_names) + + @final + def set_output_feature_names(self, input_feature_names: List[str]) -> None: + """Set output features names as a member variable. 
+ + Args: + input_feature_names: List of column names of the input to the + node definition. + """ + self._output_feature_names = self._define_output_feature_names( + input_feature_names + ) + + @abstractmethod + def _define_output_feature_names( + self, input_feature_names: List[str] + ) -> List[str]: + """Construct names of output columns. + + Args: + input_feature_names: List of column names for the input data. + + Returns: + A list of column names for each column in + the node definition output. """ - assert isinstance(node_feature_names, list) - self.nb_inputs = len(node_feature_names) @abstractmethod - def _construct_nodes(self, x: torch.tensor) -> Data: + def _construct_nodes(self, x: torch.tensor) -> Tuple[Data, List[str]]: """Construct nodes from raw node features ´x´. Args: x: standardized node features with shape ´[num_pulses, d]´, where ´d´ is the number of node features. + feature_names: List of names for reach column in `x`. Identical + order of appearance. Length `d`. Returns: graph: graph without edges. + new_node_features: A list of node features names. """ class NodesAsPulses(NodeDefinition): """Represent each measured pulse of Cherenkov Radiation as a node.""" - def _construct_nodes(self, x: torch.Tensor) -> Data: + def _define_output_feature_names( + self, input_feature_names: List[str] + ) -> List[str]: + return input_feature_names + + def _construct_nodes(self, x: torch.Tensor) -> Tuple[Data, List[str]]: return Data(x=x) + + +class PercentileClusters(NodeDefinition): + """Represent nodes as clusters with percentile summary node features. + + If `cluster_on` is set to the xyz coordinates of DOMs + e.g. `cluster_on = ['dom_x', 'dom_y', 'dom_z']`, each node will be a + unique DOM and the pulse information (charge, time) is summarized using + percentiles. + """ + + def __init__( + self, + cluster_on: List[str], + percentiles: List[int], + add_counts: bool = True, + input_feature_names: Optional[List[str]] = None, + ) -> None: + """Construct `PercentileClusters`. + + Args: + cluster_on: Names of features to create clusters from. + percentiles: List of percentiles. E.g. `[10, 50, 90]`. + add_counts: If True, number of duplicates is added to output array. + input_feature_names: (Optional) column names for input features. 
+ """ + self._cluster_on = cluster_on + self._percentiles = percentiles + self._add_counts = add_counts + # Base class constructor + super().__init__(input_feature_names=input_feature_names) + + def _define_output_feature_names( + self, input_feature_names: List[str] + ) -> List[str]: + ( + cluster_idx, + summ_idx, + new_feature_names, + ) = self._get_indices_and_feature_names( + input_feature_names, self._add_counts + ) + self._cluster_indices = cluster_idx + self._summarization_indices = summ_idx + return new_feature_names + + def _get_indices_and_feature_names( + self, + feature_names: List[str], + add_counts: bool, + ) -> Tuple[List[int], List[int], List[str]]: + cluster_idx, summ_idx, summ_names = identify_indices( + feature_names, self._cluster_on + ) + new_feature_names = deepcopy(self._cluster_on) + for feature in summ_names: + for pct in self._percentiles: + new_feature_names.append(f"{feature}_pct{pct}") + if add_counts: + # add "counts" as the last feature + new_feature_names.append("counts") + return cluster_idx, summ_idx, new_feature_names + + def _construct_nodes(self, x: torch.Tensor) -> Data: + # Cast to Numpy + x = x.numpy() + # Construct clusters with percentile-summarized features + if hasattr(self, "_summarization_indices"): + array = cluster_summarize_with_percentiles( + x=x, + summarization_indices=self._summarization_indices, + cluster_indices=self._cluster_indices, + percentiles=self._percentiles, + add_counts=self._add_counts, + ) + else: + self.error( + f"""{self.__class__.__name__} was not instatiated with + `input_feature_names` and has not been set later. + Please instantiate this class with `input_feature_names` + if you're using it outside `GraphDefinition`.""" + ) # noqa + raise AttributeError + + return Data(x=torch.tensor(array)) diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py new file mode 100644 index 000000000..ccd861783 --- /dev/null +++ b/src/graphnet/models/graphs/utils.py @@ -0,0 +1,160 @@ +"""Utility functions for construction of graphs.""" + +from typing import List, Tuple +import numpy as np + + +def lex_sort(x: np.array, cluster_columns: List[int]) -> np.ndarray: + """Sort numpy arrays according to columns on ´cluster_columns´. + + Note that `x` is sorted along the dimensions in `cluster_columns` + backwards. I.e. `cluster_columns = [0,1,2]` + means `x` is sorted along `[2,1,0]`. + + Args: + x: array to be sorted. + cluster_columns: Columns of `x` to be sorted along. + + Returns: + A sorted version of `x`. + """ + tmp_list = [] + for cluster_column in cluster_columns: + tmp_list.append(x[:, cluster_column]) + return x[np.lexsort(tuple(tmp_list)), :] + + +def gather_cluster_sequence( + x: np.ndarray, feature_idx: int, cluster_columns: List[int] +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Turn `x` into rows of clusters with sequences along columns. + + Sequences along columns are added which correspond to + gathered sequences of the feature in `x` specified by column index + `feature_idx` associated with each column. Sequences are padded with NaN to + be of same length. Dimension of clustered array is `[n_clusters, l + + len(cluster_columns)]`,where l is the largest sequence length. + + **Example**: + Suppose `x` represents a neutrino event and we have chosen to cluster on + the PMT positions and that `feature_idx` correspond to pulse time. 
+ + The resulting array will have dimensions `[n_pmts, m + 3]` where `m` is the + maximum number of same-pmt pulses found in `x`, and `+3`for the three + spatial directions defining each cluster. + + Args: + x: Array for clustering + feature_idx: Index of the feature in `x` to + be gathered for each cluster. + cluster_columns: Index in `x` from which to build clusters. + + Returns: + array: Array with dimensions `[n_clusters, l + len(cluster_columns)]` + column_offset: Indices of the columns in `array` that defines clusters. + """ + # sort pulses according to cluster columns + x = lex_sort(x=x, cluster_columns=cluster_columns) + + # Calculate clusters and counts + unique_sensors, counts = np.unique( + x[:, cluster_columns], return_counts=True, axis=0 + ) + # sort DOMs and pulse-counts + sort_this = np.concatenate([unique_sensors, counts.reshape(-1, 1)], axis=1) + sort_this = lex_sort(x=sort_this, cluster_columns=cluster_columns) + unique_sensors = sort_this[:, 0 : unique_sensors.shape[1]] + counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int) + + # Pad unique sensor columns with NaN's up until the maximum number of + # Same pmt-pulses. Each of padded columns represents a pulse. + pad = np.empty((unique_sensors.shape[0], max(counts))) + pad[:] = np.nan + array = np.concatenate([unique_sensors, pad], axis=1) + column_offset = unique_sensors.shape[1] + + # Construct indices for loop + cumsum = np.zeros(len(np.cumsum(counts)) + 1) + cumsum[0] = 0 + cumsum[1:] = np.cumsum(counts) + cumsum = cumsum.astype(int) + + # Insert pulse attribute in place of NaN. + for k in range(len(counts)): + array[k, column_offset : (column_offset + counts[k])] = x[ + cumsum[k] : cumsum[k + 1], feature_idx + ] + return array, column_offset, counts + + +def identify_indices( + feature_names: List[str], cluster_on: List[str] +) -> Tuple[List[int], List[int], List[str]]: + """Identify indices for clustering and summarization.""" + features_for_summarization = [] + for feature in feature_names: + if feature not in cluster_on: + features_for_summarization.append(feature) + cluster_indices = [feature_names.index(column) for column in cluster_on] + summarization_indices = [ + feature_names.index(column) for column in features_for_summarization + ] + return cluster_indices, summarization_indices, features_for_summarization + + +def cluster_summarize_with_percentiles( + x: np.ndarray, + summarization_indices: List[int], + cluster_indices: List[int], + percentiles: List[int], + add_counts: bool, +) -> np.ndarray: + """Turn `x` into clusters with percentile summary. + + From variables specified by column indices `cluster_indices`, `x` is turned + into clusters. Information in columns of `x` specified by indices + `summarization_indices` with each cluster is summarized using percentiles. + It is assumed `x` represents a single event. + + **Example use-case**: + Suppose `x` contains raw pulses from a neutrino event where some DOMs have + multiple measurements of Cherenkov radiation. If `cluster_indices` is set + to the columns corresponding to the xyz-position of the DOMs, and the + features specified in `summarization_indices` correspond to time, charge, + then each row in the returned array will correspond to a DOM, + and the time and charge for each DOM will be summarized by percentiles. 
+ Returned output array has dimensions + `[n_clusters, len(percentiles)*len(summarization_indices) + len(cluster_indices)]` + + Args: + x: Array to be clustered + summarization_indices: List of column indices that defines features + that will be summarized with percentiles. + cluster_indices: List of column indices on which the clusters + are constructed. + percentiles: percentiles used to summarize `x`. E.g. [10,50,90]. + + Returns: + Percentile-summarized array + """ + pct_dict = {} + for feature_idx in summarization_indices: + summarized_array, column_offset, counts = gather_cluster_sequence( + x, feature_idx, cluster_indices + ) + pct_dict[feature_idx] = np.nanpercentile( + summarized_array[:, column_offset:], percentiles, axis=1 + ).T + + for i, key in enumerate(pct_dict.keys()): + if i == 0: + array = summarized_array[:, 0:column_offset] + + array = np.concatenate([array, pct_dict[key]], axis=1) + + if add_counts: + array = np.concatenate( + [array, np.log10(counts).reshape(-1, 1)], axis=1 + ) + + return array diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml index 281bda2f4..a13f11aa2 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml @@ -19,7 +19,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: null + input_feature_names: null class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml index 6cabc6985..b42e1fef8 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: null + input_feature_names: null class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml index 3c0c7510a..326617c00 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: null + input_feature_names: null class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml index fee57a531..c54f6ec5b 100644 --- 
a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: null + input_feature_names: null class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml index 16d9ddde5..a35c0203a 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: null + input_feature_names: null class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml index a49c60a22..5e88b510a 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - node_feature_names: null + input_feature_names: null class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 0f4f6895b..25a6e5107 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -94,14 +94,22 @@ def configure_optimizers(self) -> Dict[str, Any]: ) return config - def forward(self, data: Data) -> List[Union[Tensor, Data]]: + def forward( + self, data: Union[Data, List[Data]] + ) -> List[Union[Tensor, Data]]: """Forward pass, chaining model components.""" - assert isinstance(data, Data) - x = self._gnn(data) + if isinstance(data, Data): + data = [data] + x_list = [] + for d in data: + x = self._gnn(d) + x_list.append(x) + x = torch.cat(x_list, dim=0) + preds = [task(x) for task in self._tasks] return preds - def shared_step(self, batch: Data, batch_idx: int) -> Tensor: + def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor: """Perform shared step. 
Applies the forward pass and the following loss calculation, shared @@ -111,8 +119,12 @@ def shared_step(self, batch: Data, batch_idx: int) -> Tensor: loss = self.compute_loss(preds, batch) return loss - def training_step(self, train_batch: Data, batch_idx: int) -> Tensor: + def training_step( + self, train_batch: Union[Data, List[Data]], batch_idx: int + ) -> Tensor: """Perform training step.""" + if isinstance(train_batch, Data): + train_batch = [train_batch] loss = self.shared_step(train_batch, batch_idx) self.log( "train_loss", @@ -125,8 +137,12 @@ def training_step(self, train_batch: Data, batch_idx: int) -> Tensor: ) return loss - def validation_step(self, val_batch: Data, batch_idx: int) -> Tensor: + def validation_step( + self, val_batch: Union[Data, List[Data]], batch_idx: int + ) -> Tensor: """Perform validation step.""" + if isinstance(val_batch, Data): + val_batch = [val_batch] loss = self.shared_step(val_batch, batch_idx) self.log( "val_loss", @@ -140,11 +156,21 @@ def validation_step(self, val_batch: Data, batch_idx: int) -> Tensor: return loss def compute_loss( - self, preds: Tensor, data: Data, verbose: bool = False + self, preds: Tensor, data: List[Data], verbose: bool = False ) -> Tensor: """Compute and sum losses across tasks.""" + data_merged = {} + target_labels_merged = list(set(self.target_labels)) + for label in target_labels_merged: + data_merged[label] = torch.cat([d[label] for d in data], dim=0) + for task in self._tasks: + if task._loss_weight is not None: + data_merged[task._loss_weight] = torch.cat( + [d[task._loss_weight] for d in data], dim=0 + ) + losses = [ - task.compute_loss(pred, data) + task.compute_loss(pred, data_merged) for task, pred in zip(self._tasks, preds) ] if verbose: @@ -154,8 +180,8 @@ def compute_loss( ), "Please reduce loss for each task separately" return torch.sum(torch.stack(losses)) - def _get_batch_size(self, data: Data) -> int: - return torch.numel(torch.unique(data.batch)) + def _get_batch_size(self, data: List[Data]) -> int: + return sum([torch.numel(torch.unique(d.batch)) for d in data]) def inference(self) -> None: """Activate inference mode.""" diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index ec3b4c461..df7c92e15 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -28,6 +28,44 @@ def collate_fn(graphs: List[Data]) -> Batch: return Batch.from_data_list(graphs) +class collator_sequence_buckleting: + """Perform the sequence bucketing for the graphs in the batch.""" + + def __init__(self, batch_splits: List[float] = [0.8]): + """Set cutting points of the different mini-batches. + + batch_splits: list of floats, each element is the fraction of the total + number of graphs. This list should not explicitly define the first and + last elements, which will always be 0 and 1 respectively. + """ + self.batch_splits = batch_splits + + def __call__(self, graphs: List[Data]) -> Batch: + """Execute sequence bucketing on the input list of graphs. + + Args: + graphs: A list of Data objects representing the input graphs. + + Returns: + A list of Batch objects, each containing a mini-batch of the input + graphs sorted by their number of pulses. 
+ """ + graphs = [g for g in graphs if g.n_pulses > 1] + graphs.sort(key=lambda x: x.n_pulses) + batch_list = [] + + for minp, maxp in zip( + [0] + self.batch_splits, self.batch_splits + [1] + ): + min_idx = int(minp * len(graphs)) + max_idx = int(maxp * len(graphs)) + this_graphs = graphs[min_idx:max_idx] + if len(this_graphs) > 0: + this_batch = Batch.from_data_list(this_graphs) + batch_list.append(this_batch) + return batch_list + + # @TODO: Remove in favour of DataLoader{,.from_dataset_config} def make_dataloader( db: str, diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index 64fcd85c6..480f11d4d 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -115,7 +115,7 @@ def test_dataset(backend: str) -> None: detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) # Constructor DataConverter instance @@ -168,7 +168,7 @@ def test_datasetquery_table(backend: str) -> None: detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) # Constructor DataConverter instance pulsemap = "SRTInIcePulses" @@ -220,7 +220,7 @@ def test_parquet_to_sqlite_converter() -> None: detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) # Perform conversion from I3 to `backend` database_name = FILE_NAME + "_from_parquet" diff --git a/tests/models/test_graph_definition.py b/tests/models/test_graph_definition.py index bf16d7853..ec6c75e24 100644 --- a/tests/models/test_graph_definition.py +++ b/tests/models/test_graph_definition.py @@ -27,7 +27,7 @@ def test_graph_definition() -> None: detector=Prometheus(), perturbation_dict=perturbation_dict, seed=seed ) original_output = graph_definition( - node_features=deepcopy(mock_data), node_feature_names=features + input_features=deepcopy(mock_data), input_feature_names=features ) for _ in range(n_reps): @@ -42,11 +42,11 @@ def test_graph_definition() -> None: ) data = graph_definition( - node_features=deepcopy(mock_data), node_feature_names=features + input_features=deepcopy(mock_data), input_feature_names=features ) perturbed_data = graph_definition_perturbed( - node_features=deepcopy(mock_data), node_feature_names=features + input_features=deepcopy(mock_data), input_feature_names=features ) assert ~torch.equal(data.x, perturbed_data.x) # should not be equal. diff --git a/tests/models/test_node_definition.py b/tests/models/test_node_definition.py new file mode 100644 index 000000000..4c199abd6 --- /dev/null +++ b/tests/models/test_node_definition.py @@ -0,0 +1,80 @@ +"""Unit tests for node definitions.""" +import numpy as np +import pandas as pd +import sqlite3 +import torch +from graphnet.models.graphs.nodes import PercentileClusters +from graphnet.constants import EXAMPLE_DATA_DIR + + +def test_percentile_cluster() -> None: + """Test that percentiles outputted by PercentileCluster. + + Here we check that it matches percentiles obtained from "traditional" ways. 
+ """ + # definitions + percentiles = [0, 10, 50, 90, 100] + database = f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db" + # Grab first event in database + with sqlite3.connect(database) as con: + query = "select event_no from mc_truth limit 1" + event_no = pd.read_sql(query, con) + query = f'select sensor_pos_x, sensor_pos_y, sensor_pos_z, t from total where event_no = {str(event_no["event_no"][0])}' + df = pd.read_sql(query, con) + + # Save original feature names, create variables. + original_features = list(df.columns) + x = np.array(df) + tensor = torch.tensor(x) + + # Construct node definition + # This defines each DOM as a cluster, and will summarize pulses seen by + # DOMs using percentiles. + node_definition = PercentileClusters( + cluster_on=["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"], + percentiles=percentiles, + input_feature_names=original_features, + ) + + # Apply node definition to torch tensor with raw pulses + graph, new_features = node_definition(tensor) + x_tilde = graph.x.numpy() + + # Calculate percentiles "the normal way" and compare that output of + # node definition match. + + unique_doms = ( + df.groupby(["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]) + .size() + .reset_index() + ) + for i in range(len(unique_doms)): + idx_original = ( + (df["sensor_pos_x"] == unique_doms["sensor_pos_x"][i]) + & ((df["sensor_pos_y"] == unique_doms["sensor_pos_y"][i])) + & (df["sensor_pos_z"] == unique_doms["sensor_pos_z"][i]) + ) + idx_tilde = ( + ( + x_tilde[:, new_features.index("sensor_pos_x")] + == unique_doms["sensor_pos_x"][i] + ) + & ( + x_tilde[:, new_features.index("sensor_pos_y")] + == unique_doms["sensor_pos_y"][i] + ) + & ( + x_tilde[:, new_features.index("sensor_pos_z")] + == unique_doms["sensor_pos_z"][i] + ) + ) + for percentile in percentiles: + pct_idx = new_features.index(f"t_pct{percentile}") + try: + assert np.isclose( + x_tilde[idx_tilde, pct_idx], + np.percentile(df.loc[idx_original, "t"], percentile), + ) + except AssertionError as e: + print(f"Percentile {percentile} does not match.") + raise e diff --git a/tests/models/test_task.py b/tests/models/test_task.py index 68e014f33..bfadb6263 100644 --- a/tests/models/test_task.py +++ b/tests/models/test_task.py @@ -18,7 +18,7 @@ def test_transform_prediction_and_target() -> None: detector=IceCube86(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) gnn = DynEdge( nb_inputs=graph_definition.nb_outputs, diff --git a/tests/training/test_dataloader_utilities.py b/tests/training/test_dataloader_utilities.py index 0fdaccf60..423b2f34b 100644 --- a/tests/training/test_dataloader_utilities.py +++ b/tests/training/test_dataloader_utilities.py @@ -22,7 +22,7 @@ detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) diff --git a/tests/utilities/test_dataset_config.py b/tests/utilities/test_dataset_config.py index 5f7de5b6a..ca906d659 100644 --- a/tests/utilities/test_dataset_config.py +++ b/tests/utilities/test_dataset_config.py @@ -30,7 +30,7 @@ detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) diff --git a/tests/utilities/test_model_config.py b/tests/utilities/test_model_config.py index 8979f0255..59eb6343a 100644 --- a/tests/utilities/test_model_config.py +++ 
b/tests/utilities/test_model_config.py @@ -49,7 +49,7 @@ def test_complete_model_config(path: str = "/tmp/complete_model.yml") -> None: detector=IceCubeDeepCore(), node_definition=NodesAsPulses(), nb_nearest_neighbours=8, - node_feature_names=FEATURES.DEEPCORE, + input_feature_names=FEATURES.DEEPCORE, ) gnn = DynEdge( nb_inputs=graph_definition.nb_outputs,