diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 282dd2fde..141ba41b4 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -23,8 +23,8 @@ Appendix: ## 1. Introduction -GraphNeT is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using graph neural networks (GNNs). -The framework builds on [PyTorch](https://pytorch.org/), [PyG](https://www.pyg.org/), and [PyTorch-Lightning](https://www.pytorchlightning.ai/index.html), but attempts to abstract away many of the lower-level implementation details and instead provide simple, high-level components that makes it easy and fast for physicists to use GNNs in their research. +GraphNeT is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using deep learning (DL). +The framework builds on [PyTorch](https://pytorch.org/), [PyG](https://www.pyg.org/), and [PyTorch-Lightning](https://www.pytorchlightning.ai/index.html), but attempts to abstract away many of the lower-level implementation details and instead provide simple, high-level components that makes it easy and fast for physicists to use DL in their research. This tutorial aims to introduce the various elements of `GraphNeT` to new users. It will go through the main modules, explain some of the structure and design behind these, and show concrete code examples. @@ -41,13 +41,13 @@ If you want to get your hands dirty right away, feel free to skip to [Section 3 ## 2. Overview of GraphNeT The main modules of GraphNeT are, in the order that you will likely use them: -- [`graphnet.data`](src/graphnet/data): For converting domain-specific data (i.e., I3 in the case of IceCube) to generic, intermediate file formats (e.g., SQLite or Parquet) using [`DataConverter`](src/graphnet/data/dataconverter.py); and for reading data as graphs from these intermediate files when training GNNs using [`Dataset`](src/graphnet/data/dataset.py), and its format-specific subclasses and [`DataLoader`](src/graphnet/data/dataloader.py). -- [`graphnet.models`](src/graphnet/models): For building GNNs to perform a variety of physics tasks. The base [`Model`](src/graphnet/models/model.py) class provides common interfaces for training and inference, as well as for model management (saving, loading, configs, etc.). This can be subclassed to build and train any GNN using GraphNeT functionality. The more specialised [`StandardModel`](src/graphnet/models/standard_model.py) provides a simple way to create a standard type of `Model` with a fixed structure. This type of model is composed of the following components, in sequence: +- [`graphnet.data`](src/graphnet/data): For converting domain-specific data (i.e., I3 in the case of IceCube) to generic, intermediate file formats (e.g., SQLite or Parquet) using [`DataConverter`](src/graphnet/data/dataconverter.py); and for reading data as graphs from these intermediate files when training using [`Dataset`](src/graphnet/data/dataset.py), and its format-specific subclasses and [`DataLoader`](src/graphnet/data/dataloader.py). +- [`graphnet.models`](src/graphnet/models): For building models to perform a variety of physics tasks. The base [`Model`](src/graphnet/models/model.py) class provides common interfaces for training and inference, as well as for model management (saving, loading, configs, etc.). 
This can be subclassed to build and train any model using GraphNeT functionality. The more specialised [`StandardModel`](src/graphnet/models/standard_model.py) provides a simple way to create a standard type of `Model` with a fixed structure. This type of model is composed of the following components, in sequence: - [`GraphDefinition`](src/graphnet/models/graphs/graph_definition.py): A single, self-contained module that handles all processing from raw data to graph representation. It consists of the following sub-modules in sequence: - [`Detector`](src/graphnet/models/detector/detector.py): For handling detector-specific preprocessing of data. Currently, this module provides standardization of experiment specific input data. - - [`NodeDefinition`](src/graphnet/models/graphs/nodes/nodes.py): A swapable module that defines what a node represents. In charge of transforming the collection of standardized Cherenkov pulses associated with a triggered event into a node representation of choice. It is the choice in this module that defines if nodes represents single Cherenkov pulses, DOMs, entire strings or something completely different. **Note**: You can create `NodeDefinitions` that represents the data as sequences, images or whatever you fancy, making GraphNeT compatible with any deep learning paradigm, such as CNNs, Transformers etc. - - [`EdgeDefinition`](src/graphnet/models/graphs/edges/edges.py) (Optional): A module that defines how edges are drawn between your nodes. This could be connecting the _N_ nearest neighbours of each node or connecting all nodes within a radius of _R_ meters of each other. - - [`GNN`](src/graphnet/models/gnn/gnn.py): For implementing the actual, learnable GNN layers. These are the components of GraphNeT that are actually being trained, and the architecture and complexity of these are central to the performance and optimisation on the physics/learning task being performed. For now, we provide a few different example architectures, e.g., [`DynEdge`](src/graphnet/models/gnn/convnet.py) and [`ConvNet`](src/graphnet/models/gnn/convnet.py), but in principle any GNN architecture could be implemented here — and we encourage you to contribute your favourite! + - [`NodeDefinition`](src/graphnet/models/graphs/nodes/nodes.py): A swappable module that defines what a node/row represents. In charge of transforming the collection of standardized Cherenkov pulses associated with a triggered event into a node/row representation of choice. It is the choice in this module that defines if nodes/rows represent single Cherenkov pulses, DOMs, entire strings or something completely different. **Note**: You can create `NodeDefinitions` that represent the data as sequences, images or whatever you fancy, making GraphNeT compatible with any deep learning paradigm, such as CNNs, Transformers etc. + - [`EdgeDefinition`](src/graphnet/models/graphs/edges/edges.py) (Optional): A module that defines how edges are drawn between your nodes. This could be connecting the _N_ nearest neighbours of each node or connecting all nodes within a radius of _R_ meters of each other. For methods that do not directly use edges in their data representations, this module can be skipped. + - [`backbone`](src/graphnet/models/gnn/gnn.py): For implementing the actual model architecture. These are the components of GraphNeT that are actually being trained, and the architecture and complexity of these are central to the performance and optimisation on the physics/learning task being performed. 
For now, we provide a few different example architectures, e.g., [`DynEdge`](src/graphnet/models/gnn/convnet.py) and [`ConvNet`](src/graphnet/models/gnn/convnet.py), but in principle any DL architecture could be implemented here — and we encourage you to contribute your favourite! - [`Task`](src/graphnet/models/task/task.py): For choosing a certain physics/learning task or tasks with respect to which the model should be trained. We provide a number of common [reconstruction](src/grapnet/models/task/reconstruction.py) (`DirectionReconstructionWithKappa` and `EnergyReconstructionWithUncertainty`) and [classification](src/grapnet/models/task/classification.py) (e.g., `BinaryClassificationTask` and `MulticlassClassificationTask`) tasks, but we encourage you to expand on these with new, more specialised tasks appropriate to your physics use case. For now, `Task` instances also require an appropriate [`LossFunction`](src/graphnet/training/loss_functions.py) to specify how the models should be trained (see below). These components are packaged in a particularly simple way in `StandardModel`, but they are not specific to it. @@ -61,7 +61,7 @@ In the following sections, we will go through some of the main elements of Graph ## 3. Data -You will probably want to train and apply GNN models on your own physics data. There are some pointers to this in Sections [A - Interfacing your data with GraphNeT](#a-interfacing-your-data-with-graphnet) and [B - Converting your data to a supported format](#b-converting-your-data-to-a-supported-format) below. +You will probably want to train and apply models on your own physics data. There are some pointers to this in Sections [A - Interfacing your data with GraphNeT](#a-interfacing-your-data-with-graphnet) and [B - Converting your data to a supported format](#b-converting-your-data-to-a-supported-format) below. However, to get you started, GraphNeT comes with a tiny open-source data sample. You will not be able to train a fully-deployable model with such low statistics, but it's sufficient to introduce the code and start running a few examples. @@ -391,13 +391,13 @@ That is, conceptually, > Data → `Model` → Predictions -You can subclass the `Model` class to create any model implementation using GraphNeT components (such as instances of, e.g., the `GraphDefinition`, `GNN`, and `Task` classes) along with PyTorch and PyG functionality. +You can subclass the `Model` class to create any model implementation using GraphNeT components (such as instances of, e.g., the `GraphDefinition`, `Backbone`, and `Task` classes) along with PyTorch and PyG functionality. All `Model`s that are applicable to the same detector configuration, regardless of how the `Model`s themselves are implemented, should be able to act on the same graph (`torch_geometric.data.Data`) objects, thereby making them interchangeable and directly comparable. ### The `StandardModel` class The simplest way to define a `Model` in GraphNeT is through the `StandardModel` subclass. -This is uniquely defined based on one each of[`GraphDefinition`](), [`GNN`](https://graphnet-team.github.io/graphnet/api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn), and one or more [`Task`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.task.html#module-graphnet.models.task.task)s.Each of these components will be a problem-specific instance of these parent classes.This structure guarantees modularity and reuseability. 
For example, the only adaptation needed to run a `Model` made for IceCube on a different experiment — say, KM3NeT — would be to switch out the `Detector` component in `GraphDefinition` representing IceCube with one that represents KM3NeT. Similarly, a `Model` developed for [`EnergyReconstruction`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction) can be put to work on a different problem, e.g., [`DirectionReconstructionWithKappa`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa), by switching out just the [`Task`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.task.html#module-graphnet.models.task.task) component. +This is uniquely defined based on one each of [`GraphDefinition`](), [`Backbone`](https://graphnet-team.github.io/graphnet/api/graphnet.models.gnn.gnn.html#module-graphnet.models.gnn.gnn), and one or more [`Task`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.task.html#module-graphnet.models.task.task)s. Each of these components will be a problem-specific instance of these parent classes. This structure guarantees modularity and reusability. For example, the only adaptation needed to run a `Model` made for IceCube on a different experiment — say, KM3NeT — would be to switch out the `Detector` component in `GraphDefinition` representing IceCube with one that represents KM3NeT. Similarly, a `Model` developed for [`EnergyReconstruction`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.EnergyReconstruction) can be put to work on a different problem, e.g., [`DirectionReconstructionWithKappa`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.reconstruction.html#graphnet.models.task.reconstruction.DirectionReconstructionWithKappa), by switching out just the [`Task`](https://graphnet-team.github.io/graphnet/api/graphnet.models.task.task.html#module-graphnet.models.task.task) component. GraphNeT comes with many pre-defined components that you can simply import and use out-of-the-box. @@ -427,12 +427,12 @@ graph_definition = KNNGraph( node_definition=NodesAsPulses(), nb_nearest_neighbours=8, ) -architecture = DynEdge( +backbone = DynEdge( nb_inputs=detector.nb_outputs, global_pooling_schemes=["min", "max", "mean"], ) task = ZenithReconstructionWithKappa( - hidden_size=architecture.nb_outputs, + hidden_size=backbone.nb_outputs, target_labels="injection_zenith", loss_function=VonMisesFisher2DLoss(), ) @@ -440,14 +440,14 @@ task = ZenithReconstructionWithKappa( # Construct the Model model = StandardModel( graph_definition=graph_definition, - architecture=architecture, + backbone=backbone, tasks=[task], ) ``` -**Note:** We're adding the argument `global_pooling_schemes=["min", "max", "mean"],` to the `GNN` component, since by default, no global pooling is performed. +**Note:** We're adding the argument `global_pooling_schemes=["min", "max", "mean"],` to the `Backbone` component, since by default, no global pooling is performed by this specific method. This is relevant when doing node-/hit-level predictions. -However, when doing graph-/event-level predictions, we want to perform a global pooling after the last layer of the `GNN`. +However, when doing graph-/event-level predictions, we want to perform a global pooling after the last layer of this `GNN`. 
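To make the detector swap described above concrete, here is a minimal, hedged sketch. It assumes an `IceCube86` detector class under `graphnet.models.detector.icecube` (an import path not shown in this diff); `Prometheus`, `KNNGraph`, and `NodesAsPulses` follow the usage shown elsewhere in this document.

```python
# Minimal sketch: moving a Model between experiments only requires swapping
# the Detector inside the GraphDefinition; the node definition, backbone,
# and task stay unchanged.
from graphnet.models.detector.icecube import IceCube86  # assumed import path
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses

# Graph definition for IceCube-86 data.
graph_definition_icecube = KNNGraph(
    detector=IceCube86(),
    node_definition=NodesAsPulses(),
    nb_nearest_neighbours=8,
)

# The same recipe for Prometheus data: only the Detector component changes.
graph_definition_prometheus = KNNGraph(
    detector=Prometheus(),
    node_definition=NodesAsPulses(),
    nb_nearest_neighbours=8,
)
```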
### Creating reproducible `Model`s using `ModelConfig` @@ -540,7 +540,7 @@ class_name: StandardModel ## 6. Training `Model`s and tracking experiments - `Model`s in GraphNeT comes with a powerful in-built [`Model.fit`](https://graphnet-team.github.io/graphnet/api/graphnet.models.model.html#graphnet.models.model.Model.fit) method that reduces the training of GNNs on neutrino telescopes to a syntax that is similar to that of `sklearn`: + `Model`s in GraphNeT come with a powerful in-built [`Model.fit`](https://graphnet-team.github.io/graphnet/api/graphnet.models.model.html#graphnet.models.model.Model.fit) method that reduces the training of models on neutrino telescopes to a syntax that is similar to that of `sklearn`: ```python model = Model(...) @@ -764,7 +764,7 @@ Similarly, every class inheriting from `Logger` can use the same methods as, e.g ## A. Interfacing your data with GraphNeT -GraphNeT currently supports two data format — Parquet and SQLite — and you must therefore provide your data in either of these formats for training a GNN. +GraphNeT currently supports two data formats — Parquet and SQLite — and you must therefore provide your data in either of these formats for training a `Model`. This is done using the `DataConverter` class. Performing this conversion into one of the two supported formats can be a somewhat time-consuming task, but it is only done once, and then you are free to perform all of the training and optimisation you want. diff --git a/README.md b/README.md index 7a2e69af9..6e81bd95c 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ ## :rocket: About -**GraphNeT** is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using graph neural networks (GNNs). GraphNeT makes it fast and easy to train complex models that can provide event reconstruction with state-of-the-art performance, for arbitrary detector configurations, with inference times that are orders of magnitude faster than traditional reconstruction techniques. +**GraphNeT** is an open-source Python framework aimed at providing high quality, user friendly, end-to-end functionality to perform reconstruction tasks at neutrino telescopes using deep learning (DL). GraphNeT makes it fast and easy to train complex models that can provide event reconstruction with state-of-the-art performance, for arbitrary detector configurations, with inference times that are orders of magnitude faster than traditional reconstruction techniques. Feel free to join the [GraphNeT Slack group](https://join.slack.com/t/graphnet-team/signup)! @@ -128,7 +128,7 @@ You can use any of the following Docker image tags: ## :ringed_planet: Use cases -Below is an incomplete list of potential use cases for GNNs in neutrino telescopes. +Below is an incomplete list of potential use cases for deep learning in neutrino telescopes. These are categorised as either "Reconstruction challenges" that are considered common and that may benefit several experiments physics analyses; and those same "Experiments" and "Physics analyses".
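The appendix hunk above notes that conversion to a supported format is handled by the `DataConverter` class. Below is a minimal, hedged sketch of that workflow using the pre-configured `I3ToSQLiteConverter` exposed in `graphnet.data` by this diff; it assumes the converter accepts the same keyword arguments as the converter call in the updated `examples/01_icetray/01_convert_i3_files.py`, and the paths are placeholders.

```python
# Hedged sketch of converting I3 files to SQLite and merging the outputs.
# Keyword arguments mirror the updated example script; check the actual
# I3ToSQLiteConverter signature before use.
from glob import glob

from graphnet.data import I3ToSQLiteConverter
from graphnet.data.extractors.icecube import (
    I3FeatureExtractorIceCube86,
    I3TruthExtractor,
)

input_dir = "/path/to/i3/files"  # placeholder directory with .i3(.zst) files
gcd_rescue = glob(f"{input_dir}/*GeoCalib*")[0]  # fallback GCD file

converter = I3ToSQLiteConverter(
    extractors=[
        I3FeatureExtractorIceCube86("SRTInIcePulses"),
        I3TruthExtractor(),
    ],
    outdir="/path/to/output",
    gcd_rescue=gcd_rescue,
)
converter(input_dir)     # one intermediate file per input file
converter.merge_files()  # merged output is written to <outdir>/merged
```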
diff --git a/assets/identity/graphnet-logo-and-wordmark.png b/assets/identity/graphnet-logo-and-wordmark.png index 087843a3d..08b62cae5 100644 Binary files a/assets/identity/graphnet-logo-and-wordmark.png and b/assets/identity/graphnet-logo-and-wordmark.png differ diff --git a/data/geometry_tables/icecube/icecube86.parquet b/data/geometry_tables/icecube/icecube86.parquet index 776949d61..94beceb94 100644 Binary files a/data/geometry_tables/icecube/icecube86.parquet and b/data/geometry_tables/icecube/icecube86.parquet differ diff --git a/data/geometry_tables/icecube/icecube_upgrade.parquet b/data/geometry_tables/icecube/icecube_upgrade.parquet index cadbd967a..18aa7e8ad 100644 Binary files a/data/geometry_tables/icecube/icecube_upgrade.parquet and b/data/geometry_tables/icecube/icecube_upgrade.parquet differ diff --git a/data/ice_properties/ice_transparency.parquet b/data/ice_properties/ice_transparency.parquet new file mode 100644 index 000000000..2a4068a58 Binary files /dev/null and b/data/ice_properties/ice_transparency.parquet differ diff --git a/data/tests/parquet/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.parquet b/data/tests/parquet/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.parquet index d9dfb6e98..f633a10c3 100644 Binary files a/data/tests/parquet/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.parquet and b/data/tests/parquet/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.parquet differ diff --git a/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db b/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db index befb894e7..67d190132 100644 Binary files a/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db and b/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db differ diff --git a/data/tests/sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db b/data/tests/sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db index 84369a46c..d05a12ecc 100644 Binary files a/data/tests/sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db and b/data/tests/sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db differ diff --git a/examples/01_icetray/01_convert_i3_files.py b/examples/01_icetray/01_convert_i3_files.py index 88dcf714a..9f0795cb1 100644 --- a/examples/01_icetray/01_convert_i3_files.py +++ b/examples/01_icetray/01_convert_i3_files.py @@ -1,9 +1,10 @@ """Example of converting I3-files to SQLite and Parquet.""" import os +from glob import glob from graphnet.constants import EXAMPLE_OUTPUT_DIR, TEST_DATA_DIR -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, I3FeatureExtractorIceCube86, I3RetroExtractor, @@ -41,17 +42,22 @@ def main_icecube86(backend: str) -> None: inputs = [f"{TEST_DATA_DIR}/i3/oscNext_genie_level7_v02"] outdir = f"{EXAMPLE_OUTPUT_DIR}/convert_i3_files/ic86" + gcd_rescue = glob( + "{TEST_DATA_DIR}/i3/oscNext_genie_level7_v02/*GeoCalib*" + )[0] - converter: DataConverter = CONVERTER_CLASS[backend]( - [ + converter = CONVERTER_CLASS[backend]( + extractors=[ I3FeatureExtractorIceCube86("SRTInIcePulses"), I3TruthExtractor(), ], - outdir, + outdir=outdir, + gcd_rescue=gcd_rescue, + 
workers=1, ) converter(inputs) if backend == "sqlite": - converter.merge_files(os.path.join(outdir, "merged")) + converter.merge_files() def main_icecube_upgrade(backend: str) -> None: @@ -61,25 +67,25 @@ def main_icecube_upgrade(backend: str) -> None: inputs = [f"{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998"] outdir = f"{EXAMPLE_OUTPUT_DIR}/convert_i3_files/upgrade" + gcd_rescue = glob( + "{TEST_DATA_DIR}/i3/upgrade_genie_step4_140028_000998/*GeoCalib*" + )[0] workers = 1 converter: DataConverter = CONVERTER_CLASS[backend]( - [ + extractors=[ I3TruthExtractor(), I3RetroExtractor(), I3FeatureExtractorIceCubeUpgrade("I3RecoPulseSeriesMap_mDOM"), I3FeatureExtractorIceCubeUpgrade("I3RecoPulseSeriesMap_DEgg"), ], - outdir, + outdir=outdir, workers=workers, - # nb_files_to_batch=10, - # sequential_batch_pattern="temp_{:03d}", - # input_file_batch_pattern="[A-Z]{1}_[0-9]{5}*.i3.zst", - icetray_verbose=1, + gcd_rescue=gcd_rescue, ) converter(inputs) if backend == "sqlite": - converter.merge_files(os.path.join(outdir, "merged")) + converter.merge_files() if __name__ == "__main__": diff --git a/examples/01_icetray/02_compare_sqlite_and_parquet.py b/examples/01_icetray/02_compare_sqlite_and_parquet.py index 99250d4b0..d3874c5f2 100644 --- a/examples/01_icetray/02_compare_sqlite_and_parquet.py +++ b/examples/01_icetray/02_compare_sqlite_and_parquet.py @@ -7,7 +7,7 @@ from graphnet.data.sqlite import SQLiteDataConverter from graphnet.data.parquet import ParquetDataConverter from graphnet.data.dataset import SQLiteDataset, ParquetDataset -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, diff --git a/examples/01_icetray/03_i3_deployer_example.py b/examples/01_icetray/03_i3_deployer_example.py index f55aa769c..28d73c00d 100644 --- a/examples/01_icetray/03_i3_deployer_example.py +++ b/examples/01_icetray/03_i3_deployer_example.py @@ -10,7 +10,7 @@ PRETRAINED_MODEL_DIR, ) from graphnet.data.constants import FEATURES, TRUTH -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.utilities.argparse import ArgumentParser diff --git a/examples/01_icetray/04_i3_module_in_native_icetray_example.py b/examples/01_icetray/04_i3_module_in_native_icetray_example.py index 74da5e499..09e9b358e 100644 --- a/examples/01_icetray/04_i3_module_in_native_icetray_example.py +++ b/examples/01_icetray/04_i3_module_in_native_icetray_example.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, List, Sequence from graphnet.data.constants import FEATURES -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.constants import ( @@ -23,7 +23,6 @@ from graphnet.deployment.i3modules import ( I3InferenceModule, - GraphNeTI3Module, ) ERROR_MESSAGE_MISSING_ICETRAY = ( @@ -43,7 +42,7 @@ def apply_to_files( i3_files: List[str], gcd_file: str, output_folder: str, - modules: Sequence["GraphNeTI3Module"], + modules: Sequence["I3InferenceModule"], ) -> None: """Will start an IceTray read/write chain with graphnet modules. diff --git a/examples/04_training/05_train_RNN_TITO.py b/examples/04_training/05_train_RNN_TITO.py new file mode 100644 index 000000000..26979fd5d --- /dev/null +++ b/examples/04_training/05_train_RNN_TITO.py @@ -0,0 +1,267 @@ +"""Example of training RNN-TITO model. + +with time-series data. 
+""" + +import os +from typing import Any, Dict, List, Optional + +from pytorch_lightning.loggers import WandbLogger +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR +from graphnet.data.constants import FEATURES, TRUTH +from graphnet.models import StandardModel +from graphnet.models.detector.prometheus import Prometheus +from graphnet.models.gnn import RNN_TITO +from graphnet.models.graphs import KNNGraph +from graphnet.models.graphs.nodes import NodeAsDOMTimeSeries +from graphnet.models.task.reconstruction import ( + DirectionReconstructionWithKappa, +) +from graphnet.training.labels import Direction +from graphnet.training.loss_functions import VonMisesFisher3DLoss +from graphnet.training.utils import make_train_validation_dataloader +from graphnet.utilities.argparse import ArgumentParser +from graphnet.utilities.logging import Logger + +# Constants +features = FEATURES.PROMETHEUS +truth = TRUTH.PROMETHEUS + + +def main( + path: str, + pulsemap: str, + target: str, + truth_table: str, + gpus: Optional[List[int]], + max_epochs: int, + early_stopping_patience: int, + batch_size: int, + num_workers: int, + wandb: bool = False, +) -> None: + """Run example.""" + # Construct Logger + logger = Logger() + + # Initialise Weights & Biases (W&B) run + if wandb: + # Make sure W&B output directory exists + wandb_dir = "./wandb/" + os.makedirs(wandb_dir, exist_ok=True) + wandb_logger = WandbLogger( + project="example-script", + entity="graphnet-team", + save_dir=wandb_dir, + log_model=True, + ) + + logger.info(f"features: {features}") + logger.info(f"truth: {truth}") + + # Configuration + config: Dict[str, Any] = { + "path": path, + "pulsemap": pulsemap, + "batch_size": batch_size, + "num_workers": num_workers, + "target": target, + "early_stopping_patience": early_stopping_patience, + "fit": { + "gpus": gpus, + "max_epochs": max_epochs, + }, + } + + graph_definition = KNNGraph( + detector=Prometheus(), + node_definition=NodeAsDOMTimeSeries( + keys=features, + id_columns=features[0:3], + time_column=features[-1], + charge_column="None", + ), + ) + archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_RNN_TITO_model") + run_name = "RNN_TITO_{}_example".format(config["target"]) + if wandb: + # Log configuration to W&B + wandb_logger.experiment.config.update(config) + + ( + training_dataloader, + validation_dataloader, + ) = make_train_validation_dataloader( + db=config["path"], + graph_definition=graph_definition, + selection=None, + pulsemaps=config["pulsemap"], + features=features, + truth=truth, + batch_size=config["batch_size"], + num_workers=config["num_workers"], + truth_table=truth_table, + index_column="event_no", + labels={ + "direction": Direction( + azimuth_key="injection_azimuth", zenith_key="injection_zenith" + ) + }, + ) + + # Building model + backbone = RNN_TITO( + nb_inputs=graph_definition.nb_outputs, + nb_neighbours=8, + time_series_columns=[4, 3], + rnn_layers=2, + rnn_hidden_size=64, + rnn_dropout=0.5, + features_subset=[0, 1, 2, 3], + dyntrans_layer_sizes=[(256, 256), (256, 256), (256, 256), (256, 256)], + post_processing_layer_sizes=[336, 256], + readout_layer_sizes=[256, 128], + global_pooling_schemes=["max"], + embedding_dim=0, + n_head=16, + use_global_features=True, + use_post_processing_layers=True, + ) + + task = DirectionReconstructionWithKappa( + hidden_size=backbone.nb_outputs, + target_labels=config["target"], + loss_function=VonMisesFisher3DLoss(), + ) + model = 
StandardModel( + graph_definition=graph_definition, + backbone=backbone, + tasks=[task], + optimizer_class=Adam, + optimizer_kwargs={"lr": 1e-03, "eps": 1e-03}, + scheduler_class=ReduceLROnPlateau, + scheduler_kwargs={ + "patience": config["early_stopping_patience"], + }, + scheduler_config={ + "frequency": 1, + "monitor": "val_loss", + }, + ) + + # Training model + + model.fit( + training_dataloader, + validation_dataloader, + early_stopping_patience=config["early_stopping_patience"], + logger=wandb_logger if wandb else None, + **config["fit"], + ) + + # Get predictions + additional_attributes = [ + "injection_zenith", + "injection_azimuth", + "event_no", + ] + prediction_columns = [ + config["target"][0] + "_x_pred", + config["target"][0] + "_y_pred", + config["target"][0] + "_z_pred", + config["target"][0] + "_kappa_pred", + ] + + assert isinstance(additional_attributes, list) # mypy + + results = model.predict_as_dataframe( + validation_dataloader, + additional_attributes=additional_attributes, + prediction_columns=prediction_columns, + gpus=config["fit"]["gpus"], + ) + + # Save predictions and model to file + db_name = path.split("/")[-1].split(".")[0] + path = os.path.join(archive, db_name, run_name) + logger.info(f"Writing results to {path}") + os.makedirs(path, exist_ok=True) + + # Save results as .csv + results.to_csv(f"{path}/results.csv") + + # Save full model (including weights) to .pth file - Not version proof + model.save(f"{path}/model.pth") + + # Save model config and state dict - Version safe save method. + model.save_state_dict(f"{path}/state_dict.pth") + model.save_config(f"{path}/model_config.yml") + + +if __name__ == "__main__": + + # Parse command-line arguments + parser = ArgumentParser( + description=""" +Train GNN model without the use of config files. 
+""" + ) + + parser.add_argument( + "--path", + help="Path to dataset file (default: %(default)s)", + default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db", + ) + + parser.add_argument( + "--pulsemap", + help="Name of pulsemap to use (default: %(default)s)", + default="total", + ) + + parser.add_argument( + "--target", + help=( + "Name of feature to use as regression target (default: " + "%(default)s)" + ), + default="direction", + ) + + parser.add_argument( + "--truth-table", + help="Name of truth table to be used (default: %(default)s)", + default="mc_truth", + ) + + parser.with_standard_arguments( + "gpus", + ("max-epochs", 1), + ("early-stopping-patience", 2), + ("batch-size", 16), + "num-workers", + ) + + parser.add_argument( + "--wandb", + action="store_true", + help="If True, Weights & Biases are used to track the experiment.", + ) + + args, unknown = parser.parse_known_args() + + main( + args.path, + args.pulsemap, + args.target, + args.truth_table, + args.gpus, + args.max_epochs, + args.early_stopping_patience, + args.batch_size, + args.num_workers, + args.wandb, + ) diff --git a/requirements/torch_gpu.txt b/requirements/torch_gpu.txt index ddcb85038..d01bdf439 100644 --- a/requirements/torch_gpu.txt +++ b/requirements/torch_gpu.txt @@ -1,4 +1,4 @@ # Contains packages requirements for GPU installation --find-links https://download.pytorch.org/whl/torch_stable.html -torch==2.1.0+cu118 +torch==2.2.0+cu118 --find-links https://data.pyg.org/whl/torch-2.2.0+cu118.html diff --git a/src/graphnet/data/__init__.py b/src/graphnet/data/__init__.py index fbb1ee095..77cbc1af8 100644 --- a/src/graphnet/data/__init__.py +++ b/src/graphnet/data/__init__.py @@ -1,6 +1,9 @@ """Modules for converting and ingesting data. `graphnet.data` enables converting domain-specific data to industry-standard, -intermediate file formats and reading this data. +intermediate file formats and reading this data. """ -from .filters import I3Filter, I3FilterMask +from .extractors.icecube.utilities.i3_filters import I3Filter, I3FilterMask +from .dataconverter import DataConverter +from .pre_configured import I3ToParquetConverter +from .pre_configured import I3ToSQLiteConverter diff --git a/src/graphnet/data/dataclasses.py b/src/graphnet/data/dataclasses.py new file mode 100644 index 000000000..846fe6ffe --- /dev/null +++ b/src/graphnet/data/dataclasses.py @@ -0,0 +1,20 @@ +"""Module containing experiment-specific dataclasses.""" + +from typing import List, Any +from dataclasses import dataclass + + +@dataclass +class I3FileSet: # noqa: D101 + i3_file: str + gcd_file: str + + +@dataclass +class Settings: + """Dataclass for workers in I3Deployer.""" + + i3_files: List[str] + gcd_file: str + output_folder: str + modules: List[Any] diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 2a67ddce9..57f6005f6 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -1,496 +1,263 @@ -"""Base `DataConverter` class(es) used in GraphNeT.""" -# type: ignore[name-defined] # Due to use of `init_global_index`. 
- -from abc import ABC, abstractmethod -from collections import OrderedDict -from dataclasses import dataclass -from functools import wraps -import itertools +"""Contains `DataConverter`.""" +from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type +from abc import abstractmethod, ABC + +from tqdm import tqdm +import numpy as np from multiprocessing import Manager, Pool, Value import multiprocessing.pool from multiprocessing.sharedctypes import Synchronized +import pandas as pd import os -import re -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from glob import glob -import numpy as np -import pandas as pd -from tqdm import tqdm -from graphnet.data.utilities.random import pairwise_shuffle -from graphnet.data.extractors import ( - I3Extractor, - I3ExtractorCollection, - I3FeatureExtractor, - I3TruthExtractor, - I3GenericExtractor, -) from graphnet.utilities.decorators import final -from graphnet.utilities.filesys import find_i3_files -from graphnet.utilities.imports import has_icecube_package from graphnet.utilities.logging import Logger -from graphnet.data.filters import I3Filter, NullSplitI3Filter - -if has_icecube_package(): - from icecube import icetray, dataio # pyright: reportMissingImports=false - +from .readers.graphnet_file_reader import GraphNeTFileReader +from .writers.graphnet_writer import GraphNeTWriter +from .extractors import Extractor +from .extractors.icecube import I3Extractor +from .dataclasses import I3FileSet -SAVE_STRATEGIES = [ - "1:1", - "sequential_batched", - "pattern_batched", -] - -# Utility classes -@dataclass -class FileSet: # noqa: D101 - i3_file: str - gcd_file: str - - -# Utility method(s) def init_global_index(index: Synchronized, output_files: List[str]) -> None: """Make `global_index` available to pool workers.""" global global_index, global_output_files # type: ignore[name-defined] global_index, global_output_files = (index, output_files) # type: ignore[name-defined] -F = TypeVar("F", bound=Callable[..., Any]) - - -def cache_output_files(process_method: F) -> F: - """Decorate `process_method` to cache output file names.""" - - @wraps(process_method) - def wrapper(self: Any, *args: Any) -> Any: - try: - # Using multiprocessing - output_files = global_output_files # type: ignore[name-defined] - except NameError: # `global_output_files` not set - # Running on main process - output_files = self._output_files - - output_file = process_method(self, *args) - output_files.append(output_file) - return output_file - - return cast(F, wrapper) - - class DataConverter(ABC, Logger): - """Base class for converting I3-files to intermediate file format.""" + """A finalized data conversion class in GraphNeT. - @property - @abstractmethod - def file_suffix(self) -> str: - """Suffix to use on output files.""" + `DataConverter` provides parallel processing of file conversion and + extraction from experiment-specific file formats to graphnet-supported data + formats. This class also assigns event id's to training examples. 
+ """ def __init__( self, - extractors: List[I3Extractor], + file_reader: GraphNeTFileReader, + save_method: GraphNeTWriter, outdir: str, - gcd_rescue: Optional[str] = None, - *, - nb_files_to_batch: Optional[int] = None, - sequential_batch_pattern: Optional[str] = None, - input_file_batch_pattern: Optional[str] = None, - workers: int = 1, + extractors: Union[List[Extractor], List[I3Extractor]], index_column: str = "event_no", - icetray_verbose: int = 0, - i3_filters: List[I3Filter] = [], - ): - """Construct DataConverter. - - When using `input_file_batch_pattern`, regular expressions are used to - group files according to their names. All files that match a certain - pattern up to wildcards are grouped into the same output file. This - output file has the same name as the input files that are group into it, - with wildcards replaced with "x". Periods (.) and wildcards (*) have a - special meaning: Periods are interpreted as literal periods, and not as - matching any character (as in standard regex); and wildcards are - interpreted as ".*" in standard regex. - - For instance, the pattern "[A-Z]{1}_[0-9]{5}*.i3.zst" will find all I3 - files whose names contain: - - one capital letter, followed by - - an underscore, followed by - - five numbers, followed by - - any string of characters ending in ".i3.zst" - - This means that, e.g., the files: - - upgrade_genie_step4_141020_A_000000.i3.zst - - upgrade_genie_step4_141020_A_000001.i3.zst - - ... - - upgrade_genie_step4_141020_A_000008.i3.zst - - upgrade_genie_step4_141020_A_000009.i3.zst - would be grouped into the output file named - "upgrade_genie_step4_141020_A_00000x." but the file - - upgrade_genie_step4_141020_A_000010.i3.zst - would end up in a separate group, named - "upgrade_genie_step4_141020_A_00001x.". - """ - # Check(s) - if not isinstance(extractors, (list, tuple)): - extractors = [extractors] - - assert ( - len(extractors) > 0 - ), "Please specify at least one argument of type I3Extractor" - - for extractor in extractors: - assert isinstance( - extractor, I3Extractor - ), f"{type(extractor)} is not a subclass of I3Extractor" - - # Infer saving strategy - save_strategy = self._infer_save_strategy( - nb_files_to_batch, - sequential_batch_pattern, - input_file_batch_pattern, - ) - - # Member variables - self._outdir = outdir - self._gcd_rescue = gcd_rescue - self._save_strategy = save_strategy - self._nb_files_to_batch = nb_files_to_batch - self._sequential_batch_pattern = sequential_batch_pattern - self._input_file_batch_pattern = input_file_batch_pattern - self._workers = workers - - # I3Filters (NullSplitI3Filter is always included) - self._i3filters = [NullSplitI3Filter()] + i3_filters - - for filter in self._i3filters: - assert isinstance( - filter, I3Filter - ), f"{type(filter)} is not a subclass of I3Filter" - - # Create I3Extractors - self._extractors = I3ExtractorCollection(*extractors) - - # Create shorthand of names of all pulsemaps queried - self._table_names = [extractor.name for extractor in self._extractors] - self._pulsemaps = [ - extractor.name - for extractor in self._extractors - if isinstance(extractor, I3FeatureExtractor) - ] + num_workers: int = 1, + ) -> None: + """Initialize `DataConverter`. - # Placeholders for keeping track of sequential event indices and output files + Args: + file_reader: The method used for reading and applying `Extractors`. + save_method: The method used to save the interim data format to + a graphnet supported file format. + outdir: The directory to save the files in. 
+ extractors: The `Extractor`(s) that will be applied to the input + files. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + """ + # Member Variable Assignment + self._file_reader = file_reader + self._save_method = save_method + self._num_workers = num_workers self._index_column = index_column self._index = 0 + self._output_dir = outdir self._output_files: List[str] = [] - # Set verbosity - if icetray_verbose == 0: - icetray.I3Logger.global_logger = icetray.I3NullLogger() + # Set Extractors. Will throw error if extractors are incompatible + # with reader. + if not isinstance(extractors, list): + extractors = [extractors] + self._file_reader.set_extractors(extractors=extractors) # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @final - def __call__( - self, - directories: Union[str, List[str]], - recursive: Optional[bool] = True, - ) -> None: - """Convert I3-files in `directories. + def __call__(self, input_dir: Union[str, List[str]]) -> None: + """Extract data from files in `input_dir` and save to disk. Args: - directories: One or more directories, the I3 files within which - should be converted to an intermediate file format. - recursive: Whether or not to search the directories recursively. + input_dir: A directory that contains the input files. + The directory will be searched recursively for files + matching the file extension. """ - # Find all I3 and GCD files in the specified directories. - i3_files, gcd_files = find_i3_files( - directories, self._gcd_rescue, recursive - ) - if len(i3_files) == 0: - self.error(f"No files found in {directories}.") - return - - # Save a record of the found I3 files in the output directory. - self._save_filenames(i3_files) - - # Shuffle I3 files to get a more uniform load on worker nodes. - i3_files, gcd_files = pairwise_shuffle(i3_files, gcd_files) - - # Process the files - filesets = [ - FileSet(i3_file, gcd_file) - for i3_file, gcd_file in zip(i3_files, gcd_files) + # Get the file reader to produce a list of input files + # in the directory + input_files = self._file_reader.find_files(path=input_dir) + self._launch_jobs(input_files=input_files) + self._output_files = [ + os.path.join( + self._output_dir, + self._create_file_name(file) + + self._save_method.file_extension, + ) + for file in input_files ] - self.execute(filesets) @final - def execute(self, filesets: List[FileSet]) -> None: - """General method for processing a set of I3 files. - - The files are converted individually according to the inheriting class/ - intermediate file format. - - Args: - filesets: List of paths to I3 and corresponding GCD files. - """ - # Make sure output directory exists. - self.info(f"Saving results to {self._outdir}") - os.makedirs(self._outdir, exist_ok=True) - - # Iterate over batches of files. - try: - if self._save_strategy == "sequential_batched": - # Define batches - assert self._nb_files_to_batch is not None - assert self._sequential_batch_pattern is not None - batches = np.array_split( - np.asarray(filesets), - int(np.ceil(len(filesets) / self._nb_files_to_batch)), - ) - batches = [ - ( - group.tolist(), - self._sequential_batch_pattern.format(ix_batch), - ) - for ix_batch, group in enumerate(batches) - ] - self.info( - f"Will batch {len(filesets)} input files into {len(batches)} groups." 
- ) - - # Iterate over batches - pool = self._iterate_over_batches_of_files(batches) - - elif self._save_strategy == "pattern_batched": - # Define batches - groups: Dict[str, List[FileSet]] = OrderedDict() - for fileset in sorted(filesets, key=lambda f: f.i3_file): - group = re.sub( - self._sub_from, - self._sub_to, - os.path.basename(fileset.i3_file), - ) - if group not in groups: - groups[group] = list() - groups[group].append(fileset) - - self.info( - f"Will batch {len(filesets)} input files into {len(groups)} groups" - ) - if len(groups) <= 20: - for group, group_filesets in groups.items(): - self.info( - f"> {group}: {len(group_filesets):3d} file(s)" - ) - - batches = [ - (list(group_filesets), group) - for group, group_filesets in groups.items() - ] - - # Iterate over batches - pool = self._iterate_over_batches_of_files(batches) - - elif self._save_strategy == "1:1": - pool = self._iterate_over_individual_files(filesets) - - else: - assert False, "Shouldn't reach here." - - self._update_shared_variables(pool) - - except KeyboardInterrupt: - self.warning("[ctrl+c] Exciting gracefully.") - - @abstractmethod - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Implementation-specific method for saving data to file. - - Args: - data: List of extracted features. - output_file: Name of output file. - """ - - @abstractmethod - def merge_files( - self, output_file: str, input_files: Optional[List[str]] = None + def _launch_jobs( + self, + input_files: Union[List[str], List[I3FileSet]], ) -> None: - """Implementation-specific method for merging output files. + """Multi Processing Logic. - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files to be merged, according to the - specific implementation. Default to None, meaning that all - files output by the current instance are merged. - - Raises: - NotImplementedError: If the method has not been implemented for the - backend in question. - """ + Spawns worker pool, + distributes the input files evenly across workers. + declare event_no as globally accessible variable across workers. + starts jobs. - # Internal methods - def _iterate_over_individual_files( - self, args: List[FileSet] - ) -> Optional[multiprocessing.pool.Pool]: + Will call process_file in parallel. + """ # Get appropriate mapping function - map_fn, pool = self.get_map_function(len(args)) + map_fn, pool = self.get_map_function(nb_files=len(input_files)) # Iterate over files for _ in map_fn( - self._process_file, tqdm(args, unit="file(s)", colour="green") + self._process_file, + tqdm(input_files, unit="file(s)", colour="green"), ): - self.debug( - "Saving with 1:1 strategy on the individual worker processes" - ) - - return pool + self.debug("processing file.") - def _iterate_over_batches_of_files( - self, args: List[Tuple[List[FileSet], str]] - ) -> Optional[multiprocessing.pool.Pool]: - """Iterate over a batch of files and save results on worker process.""" - # Get appropriate mapping function - map_fn, pool = self.get_map_function(len(args), unit="batch(es)") - - # Iterate over batches of files - for _ in map_fn( - self._process_batch, tqdm(args, unit="batch(es)", colour="green") - ): - self.debug("Saving with batched strategy") + self._update_shared_variables(pool) - return pool + @final + def _process_file(self, file_path: Union[str, I3FileSet]) -> None: + """Process a single file. 
- def _update_shared_variables( - self, pool: Optional[multiprocessing.pool.Pool] - ) -> None: - """Update `self._index` and `self._output_files`. + Calls file reader to recieve extracted output, event ids + is assigned to the extracted data and is handed to save method. - If `pool` is set, it means that multiprocessing was used. In this case, - the worker processes will not have been able to write directly to - `self._index` and `self._output_files`, and we need to get them synced - up. + This function is called in parallel. """ - if pool: - # Extract information from shared variables to member variables. - index, output_files = pool._initargs # type: ignore - self._index += index.value - self._output_files.extend(list(sorted(output_files[:]))) - - @cache_output_files - def _process_file( - self, - fileset: FileSet, - ) -> str: + # Read and apply extractors + data: List[OrderedDict] = self._file_reader(file_path=file_path) - # Process individual files - data = self._extract_data(fileset) + # Count number of events + n_events = len(data) - # Save data - output_file = self._get_output_file(fileset.i3_file) - self.save_data(data, output_file) + # Assign event_no's to each event in data and transform to pd.DataFrame + dataframes = self._assign_event_no(data=data) - return output_file + # Delete `data` to save memory + del data - @cache_output_files - def _process_batch(self, args: Tuple[List[FileSet], str]) -> str: - # Unpack arguments - filesets, output_file_name = args + # Create output file name + output_file_name = self._create_file_name(input_file_path=file_path) - # Process individual files - data = list( - itertools.chain.from_iterable(map(self._extract_data, filesets)) + # Apply save method + self._save_method( + data=dataframes, + file_name=output_file_name, + n_events=n_events, + output_dir=self._output_dir, ) - # Save batched data - output_file = self._get_output_file(output_file_name) - self.save_data(data, output_file) - - return output_file - - def _extract_data(self, fileset: FileSet) -> List[OrderedDict]: - """Extract data from single I3 file. - - If the saving strategy is 1:1 (i.e., each I3 file is converted to a - corresponding intermediate file) the data is saved to such a file, and - no data is return from the method. + @final + def _create_file_name(self, input_file_path: Union[str, I3FileSet]) -> str: + """Convert input file path to an output file name.""" + if isinstance(input_file_path, I3FileSet): + input_file_path = input_file_path.i3_file + file_name = os.path.basename(input_file_path) + for ext in self._file_reader._accepted_file_extensions: + if file_name.endswith(ext): + file_name_without_extension = file_name.replace(ext, "") + return file_name_without_extension.replace(".i3", "") - The above distincting is to allow worker processes to save files rather - than sending it back to the main process. + @final + def _assign_event_no( + self, data: List[OrderedDict] + ) -> Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]]: + + # Request event_no's for the entire file + event_nos = self._request_event_nos(n_ids=len(data)) + + # Dict holding pd.DataFrame's + dataframe_dict: Dict = {} + # Loop through events (again..) 
to assign event_nos + for k in range(len(data)): + for extractor_name in data[k].keys(): + n_rows = self._count_rows( + event_dict=data[k], extractor_name=extractor_name + ) + if n_rows > 0: + data[k][extractor_name][self._index_column] = np.repeat( + event_nos[k], n_rows + ).tolist() + df = pd.DataFrame( + data[k][extractor_name], + index=[0] if n_rows == 1 else None, + ) + if extractor_name in dataframe_dict.keys(): + dataframe_dict[extractor_name].append(df) + else: + dataframe_dict[extractor_name] = [df] + + # Merge each list of dataframes if wanted by writer + if self._save_method.expects_merged_dataframes: + for key in dataframe_dict.keys(): + dataframe_dict[key] = pd.concat( + dataframe_dict[key], axis=0 + ).reset_index(drop=True) + return dataframe_dict - Args: - fileset: Path to I3 file and corresponding GCD file. + @final + def _count_rows( + self, event_dict: OrderedDict[str, Any], extractor_name: str + ) -> int: + """Count number of rows that features from `extractor_name` have.""" + extractor_dict = event_dict[extractor_name] - Returns: - Extracted data. - """ - # Infer whether method is being run using multiprocessing try: - global_index # type: ignore[name-defined] - multi_processing = True - except NameError: - multi_processing = False - - self._extractors.set_files(fileset.i3_file, fileset.gcd_file) - i3_file_io = dataio.I3File(fileset.i3_file, "r") - data = list() - while i3_file_io.more(): - try: - frame = i3_file_io.pop_physics() - except Exception as e: - if "I3" in str(e): - continue - # check if frame should be skipped - if self._skip_frame(frame): - continue - - # Try to extract data from I3Frame - results = self._extractors(frame) - - data_dict = OrderedDict(zip(self._table_names, results)) - - # If an I3GenericExtractor is used, we want each automatically - # parsed key to be stored as a separate table. - for extractor in self._extractors: - if isinstance(extractor, I3GenericExtractor): - data_dict.update(data_dict.pop(extractor._name)) - - # Get new, unique index and increment value - if multi_processing: - with global_index.get_lock(): # type: ignore[name-defined] - index = global_index.value # type: ignore[name-defined] - global_index.value += 1 # type: ignore[name-defined] + # If all features in extractor_name have the same length + # this line of code will execute without error and result + # in an array with shape [num_features, n_rows_in_feature] + # unless the list is empty! + + shape = np.asarray(list(extractor_dict.values())).shape + if len(shape) > 1: + n_rows = shape[1] else: - index = self._index - self._index += 1 - - # Attach index to all tables - for table in data_dict.keys(): - data_dict[table][self._index_column] = index - - data.append(data_dict) + n_rows = 1 + except ValueError as e: + self.error( + f"Features from {extractor_name} ({extractor_dict.keys()}) have different lengths." 
+ ) + raise e + return n_rows + + def _request_event_nos(self, n_ids: int) -> List[int]: + + # Get new, unique index and increment value + if self._num_workers > 1: + with global_index.get_lock(): # type: ignore[name-defined] + starting_index = global_index.value # type: ignore[name-defined] + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + global_index.value += n_ids # type: ignore[name-defined] + else: + starting_index = self._index + event_nos = np.arange( + starting_index, starting_index + n_ids, 1 + ).tolist() + self._index += n_ids - return data + return event_nos + @final def get_map_function( - self, nb_files: int, unit: str = "I3 file(s)" + self, nb_files: int, unit: str = "file(s)" ) -> Tuple[Any, Optional[multiprocessing.pool.Pool]]: """Identify map function to use (pure python or multiprocess).""" # Choose relevant map-function given the requested number of workers. - workers = min(self._workers, nb_files) - if workers > 1: + n_workers = min(self._num_workers, nb_files) + if n_workers > 1: self.info( - f"Starting pool of {workers} workers to process {nb_files} {unit}" + f"Starting pool of {n_workers} workers to process {nb_files} {unit}" ) manager = Manager() @@ -498,7 +265,7 @@ def get_map_function( output_files = manager.list() pool = Pool( - processes=workers, + processes=n_workers, initializer=init_global_index, initargs=(index, output_files), ) @@ -513,75 +280,52 @@ def get_map_function( return map_fn, pool - def _infer_save_strategy( - self, - nb_files_to_batch: Optional[int] = None, - sequential_batch_pattern: Optional[str] = None, - input_file_batch_pattern: Optional[str] = None, - ) -> str: - if input_file_batch_pattern is not None: - save_strategy = "pattern_batched" - - assert ( - "*" in input_file_batch_pattern - ), "Argument `input_file_batch_pattern` should contain at least one wildcard (*)" - - fields = [ - "(" + field + ")" - for field in input_file_batch_pattern.replace( - ".", r"\." - ).split("*") - ] - nb_fields = len(fields) - self._sub_from = ".*".join(fields) - self._sub_to = "x".join([f"\\{ix + 1}" for ix in range(nb_fields)]) - - if sequential_batch_pattern is not None: - self.warning("Argument `sequential_batch_pattern` ignored.") - if nb_files_to_batch is not None: - self.warning("Argument `nb_files_to_batch` ignored.") - - elif (nb_files_to_batch is not None) or ( - sequential_batch_pattern is not None - ): - save_strategy = "sequential_batched" + @final + def _update_shared_variables( + self, pool: Optional[multiprocessing.pool.Pool] + ) -> None: + """Update `self._index` and `self._output_files`. - assert (nb_files_to_batch is not None) and ( - sequential_batch_pattern is not None - ), "Please specify both `nb_files_to_batch` and `sequential_batch_pattern` for sequential batching." + If `pool` is set, it means that multiprocessing was used. In this case, + the worker processes will not have been able to write directly to + `self._index` and `self._output_files`, and we need to get them synced + up. + """ + if pool: + # Extract information from shared variables to member variables. 
+ index, output_files = pool._initargs # type: ignore + self._index += index.value + self._output_files.extend(list(sorted(output_files[:]))) - else: - save_strategy = "1:1" - - return save_strategy - - def _save_filenames(self, i3_files: List[str]) -> None: - """Save I3 file names in CSV format.""" - self.debug("Saving input file names to config CSV.") - config_dir = os.path.join(self._outdir, "config") - os.makedirs(config_dir, exist_ok=True) - df_i3_files = pd.DataFrame(data=i3_files, columns=["filename"]) - df_i3_files.to_csv(os.path.join(config_dir, "i3files.csv")) - - def _get_output_file(self, input_file: str) -> str: - assert isinstance(input_file, str) - basename = os.path.basename(input_file) - output_file = os.path.join( - self._outdir, - re.sub(r"\.i3\..*", "", basename) + "." + self.file_suffix, - ) - return output_file + @final + def merge_files(self, files: Optional[List[str]] = None) -> None: + """Merge converted files. - def _skip_frame(self, frame: "icetray.I3Frame") -> bool: - """Check the user defined filters. + `DataConverter` will call the `.merge_files` method in the + `GraphNeTWriter` module that it was instantiated with. - Returns: - bool: True if frame should be skipped, False otherwise. + Args: + files: Intermediate files to be merged. """ - if self._i3filters is None: - return False # No filters defined, so we keep the frame - - for filter in self._i3filters: - if not filter(frame): - return True # keep_frame call false, skip the frame. - return False # All filter keep_frame calls true, keep the frame. + if (files is None) & (len(self._output_files) > 0): + # If no input files are given, but output files from conversion + # is available. + files_to_merge = self._output_files + elif files is not None: + # Proceed to merge specified by user. + files_to_merge = files + else: + # Raise error + self.error( + "This DataConverter does not have output files set," + "and you must therefore specify argument `files`." + ) + assert files is not None + + # Merge files + merge_path = os.path.join(self._output_dir, "merged") + self.info(f"Merging files to {merge_path}") + self._save_method.merge_files( + files=files_to_merge, + output_dir=merge_path, + ) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py new file mode 100644 index 000000000..e629ce4a0 --- /dev/null +++ b/src/graphnet/data/datamodule.py @@ -0,0 +1,456 @@ +"""Base `Dataloader` class(es) used in `graphnet`.""" +from typing import Dict, Any, Optional, List, Tuple, Union +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from copy import deepcopy +from sklearn.model_selection import train_test_split +import pandas as pd + +from graphnet.data.dataset import ( + Dataset, + EnsembleDataset, + SQLiteDataset, + ParquetDataset, +) +from graphnet.utilities.logging import Logger + + +class GraphNeTDataModule(pl.LightningDataModule, Logger): + """General Class for DataLoader Construction.""" + + def __init__( + self, + dataset_reference: Union[SQLiteDataset, ParquetDataset, Dataset], + dataset_args: Dict[str, Any], + selection: Optional[Union[List[int], List[List[int]]]] = None, + test_selection: Optional[Union[List[int], List[List[int]]]] = None, + train_dataloader_kwargs: Optional[Dict[str, Any]] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, + train_val_split: Optional[List[float]] = [0.9, 0.10], + split_seed: int = 42, + ) -> None: + """Create dataloaders from dataset. 
+ + Args: + dataset_reference: A non-instantiated reference + to the dataset class. + dataset_args: Arguments to instantiate + graphnet.data.dataset.Dataset with. + selection: (Optional) a list of event id's used for training + and validation, Default None. + test_selection: (Optional) a list of event id's used for testing, + Default None. + train_dataloader_kwargs: Arguments for the training DataLoader, + Default None. + validation_dataloader_kwargs: Arguments for the validation + DataLoader, Default None. + test_dataloader_kwargs: Arguments for the test DataLoader, + Default None. + train_val_split (Optional): Split ratio for training and + validation sets. Default is [0.9, 0.10]. + split_seed: seed used for shuffling and splitting selections into + train/validation, Default 42. + """ + Logger.__init__(self) + self._make_sure_root_logger_is_configured() + self._dataset = dataset_reference + self._dataset_args = dataset_args + self._selection = selection + self._test_selection = test_selection + self._train_val_split = train_val_split or [0.0] + self._rng = split_seed + + self._train_dataloader_kwargs = train_dataloader_kwargs or {} + self._validation_dataloader_kwargs = validation_dataloader_kwargs or {} + self._test_dataloader_kwargs = test_dataloader_kwargs or {} + + # If multiple dataset paths are given, we should use EnsembleDataset + self._use_ensemble_dataset = isinstance( + self._dataset_args["path"], list + ) + + self.setup("fit") + + def prepare_data(self) -> None: + """Prepare the dataset for training.""" + # Download method for curated datasets. Method for download is + # likely dataset-specific, so we can leave it as-is + pass + + def setup(self, stage: str) -> None: + """Prepare Datasets for DataLoaders. + + Args: + stage: lightning stage. Either "fit, validate, test, predict" + """ + # Sanity Checks + self._validate_dataset_class() + self._validate_dataset_args() + self._validate_dataloader_args() + + # Case-handling of selection arguments + self._resolve_selections() + + # Creation of Datasets + if ( + self._test_selection is not None + or len(self._test_dataloader_kwargs) > 0 + ): + self._test_dataset = self._create_dataset( + self._test_selection # type: ignore + ) + if stage == "fit" or stage == "validate": + if self._train_selection is not None: + self._train_dataset = self._create_dataset( + self._train_selection + ) + if self._val_selection is not None: + self._val_dataset = self._create_dataset(self._val_selection) + + return + + @property + def train_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the training DataLoader. + + Returns: + DataLoader: The DataLoader configured for training. + """ + return self._create_dataloader(self._train_dataset) + + @property + def val_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the validation DataLoader. + + Returns: + DataLoader: The DataLoader configured for validation. + """ + return self._create_dataloader(self._val_dataset) + + @property + def test_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the test DataLoader. + + Returns: + DataLoader: The DataLoader configured for testing. + """ + return self._create_dataloader(self._test_dataset) + + def teardown(self) -> None: # type: ignore[override] + """Perform any necessary cleanup or shutdown procedures. + + This method can be used for tasks such as closing SQLite connections + after training. Override this method as needed. 
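A minimal usage sketch of the new `GraphNeTDataModule` introduced in this file. The database path and the `dataset_args` keys (`pulsemaps`, `features`, `truth`) are placeholders that depend on your `Dataset` configuration, and in practice you will likely also need a suitable `collate_fn` in the dataloader kwargs for graph objects; treat this as an assumption-laden sketch, not the canonical API usage.

```python
from graphnet.data.datamodule import GraphNeTDataModule
from graphnet.data.dataset import SQLiteDataset

# Placeholder arguments forwarded to SQLiteDataset(**dataset_args).
dataset_args = {
    "path": "/path/to/events.db",
    "pulsemaps": "SRTInIcePulses",
    "features": ["dom_x", "dom_y", "dom_z", "dom_time", "charge"],
    "truth": ["energy", "zenith", "azimuth"],
}

dm = GraphNeTDataModule(
    dataset_reference=SQLiteDataset,  # non-instantiated class reference
    dataset_args=dataset_args,
    train_dataloader_kwargs={"batch_size": 128, "num_workers": 8, "shuffle": True},
    validation_dataloader_kwargs={"batch_size": 128, "num_workers": 8},
    train_val_split=[0.9, 0.10],
    split_seed=42,
)

# The dataloaders are exposed as properties, i.e. accessed without parentheses.
train_loader = dm.train_dataloader
val_loader = dm.val_dataloader
```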
+ + Returns: + None + """ + if hasattr(self, "_train_dataset") and isinstance( + self._train_dataset, SQLiteDataset + ): + self._train_dataset._close_connection() + + if hasattr(self, "_val_dataset") and isinstance( + self._val_dataset, SQLiteDataset + ): + self._val_dataset._close_connection() + + if hasattr(self, "_test_dataset") and isinstance( + self._test_dataset, SQLiteDataset + ): + self._test_dataset._close_connection() + + return + + def _create_dataloader( + self, dataset: Union[Dataset, EnsembleDataset] + ) -> DataLoader: + """Create a DataLoader for the given dataset. + + Args: + dataset (Union[Dataset, EnsembleDataset]): + The dataset to create a DataLoader for. + + Returns: + DataLoader: The DataLoader configured for the given dataset. + """ + if dataset == self._train_dataset: + dataloader_args = self._train_dataloader_kwargs + elif dataset == self._val_dataset: + dataloader_args = self._validation_dataloader_kwargs + elif dataset == self._test_dataset: + dataloader_args = self._test_dataloader_kwargs + else: + raise ValueError( + "Unknown dataset encountered during dataloader creation." + ) + + if dataloader_args is None: + raise AttributeError("Dataloader arguments not provided.") + + return DataLoader(dataset=dataset, **dataloader_args) + + def _validate_dataset_class(self) -> None: + """Sanity checks on the dataset reference (self._dataset). + + Checks whether the dataset is an instance of SQLiteDataset, + ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset + type is detected, or if an EnsembleDataset is used. + """ + allowed_types = (SQLiteDataset, ParquetDataset, Dataset) + if self._dataset not in allowed_types: + raise TypeError( + "dataset_reference must be an instance " + "of SQLiteDataset, ParquetDataset, or Dataset." + ) + if self._dataset is EnsembleDataset: + raise TypeError( + "EnsembleDataset is not allowed as dataset_reference." + ) + + def _validate_dataset_args(self) -> None: + """Sanity checks on the arguments for the dataset reference.""" + if isinstance(self._dataset_args["path"], list): + if self._selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of selections given as arg. + assert len(self._dataset_args["path"]) == len( + self._selection + ) + except AssertionError: + raise ValueError( + "The number of dataset paths" + f" ({len(self._dataset_args['path'])})" + " does not match the number of" + f" selections ({len(self._selection)})." + ) + + if self._test_selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of test selections. + assert len(self._dataset_args["path"]) == len( + self._test_selection + ) + except AssertionError: + raise ValueError( + "The number of dataset paths " + f" ({len(self._dataset_args['path'])}) does not match " + "the number of test selections " + f"({len(self._test_selection)}).If you'd like to test " + "on only a subset of the " + f"{len(self._dataset_args['path'])} datasets, " + "please provide empty test selections for the others." 
+ ) + + def _validate_dataloader_args(self) -> None: + """Sanity check on `dataloader_args`.""" + if "dataset" in self._train_dataloader_kwargs: + raise ValueError( + "`train_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._validation_dataloader_kwargs: + raise ValueError( + "`validation_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._test_dataloader_kwargs: + raise ValueError( + "`test_dataloader_kwargs` must not contain `dataset`" + ) + + def _resolve_selections(self) -> None: + if self._test_selection is None: + self.warning_once( + f"{self.__class__.__name__} did not receive an" + " argument for `test_selection` and will " + "therefore not have a prediction dataloader available." + ) + if self._selection is not None: + # Split the selection into train/validation + if self._use_ensemble_dataset: + # Split every selection + self._train_selection = [] + self._val_selection = [] + for selection in self._selection: + train_selection, val_selection = self._split_selection( + selection + ) + self._train_selection.append(train_selection) + self._val_selection.append(val_selection) + + else: + # Split the only selection we got + assert isinstance(self._selection, list) + ( + self._train_selection, + self._val_selection, + ) = self._split_selection( # type: ignore + self._selection + ) + + else: # selection is None + # If not provided, we infer it by grabbing + # all event ids in the dataset. + self.info( + f"{self.__class__.__name__} did not receive an" + " for `selection`. Selection will " + "will automatically be created with a split of " + f"train: {self._train_val_split[0]} and " + f"validation: {self._train_val_split[1]}" + ) + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections() # type: ignore + + def _split_selection( + self, selection: Union[int, List[int], List[List[int]]] + ) -> Tuple[List[int], List[int]]: + """Split train selection into train/validation. + + Args: + selection: Training selection to be split + + Returns: + Training selection, Validation selection. + """ + assert isinstance(selection, (int, list)) + if isinstance(selection, int): + flat_selection = [selection] + elif isinstance(selection[0], list): + flat_selection = [ + item + for sublist in selection + for item in sublist # type: ignore + ] + else: + flat_selection = selection # type: ignore + assert isinstance(flat_selection, list) + + train_selection, val_selection = train_test_split( + flat_selection, + train_size=self._train_val_split[0], + test_size=self._train_val_split[1], + random_state=self._rng, + ) + return train_selection, val_selection + + def _infer_selections(self) -> Tuple[List[int], List[int]]: + """Automatically infer training and validation selections. + + Returns: + Training selection, Validation selection + """ + if self._use_ensemble_dataset: + # We must iterate through the dataset paths and infer a train/val + # selection for each. 
+ self._train_selection = [] + self._val_selection = [] + for dataset_path in self._dataset_args["path"]: + ( + train_selection, + val_selection, + ) = self._infer_selections_on_single_dataset(dataset_path) + self._train_selection.append(train_selection) # type: ignore + self._val_selection.append(val_selection) # type: ignore + else: + # Infer selection on a single dataset + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections_on_single_dataset( # type: ignore + self._dataset_args["path"] + ) + + return (self._train_selection, self._val_selection) # type: ignore + + def _infer_selections_on_single_dataset( + self, dataset_path: str + ) -> Tuple[List[int], List[int]]: + """Automatically infers dataset train/val selections. + + Args: + dataset_path (str): The path to the dataset. + + Returns: + Tuple[List[int], List[int]]: Training and validation selections. + """ + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = dataset_path + tmp_dataset = self._construct_dataset(tmp_args) + + all_events = ( + tmp_dataset._get_all_indices() + ) # unshuffled list, sequential index + + # Multiple lines to avoid one large + all_events = ( + pd.DataFrame(all_events) + .sample(frac=1, replace=False, random_state=self._rng) + .values.tolist() + ) # shuffled list + + return self._split_selection(all_events) + + def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: + """Construct dataset. + + Return: + Dataset object constructed from input arguments. + """ + dataset = self._dataset(**tmp_args) # type: ignore + return dataset + + def _create_dataset( + self, selection: Union[List[int], List[List[int]], List[float]] + ) -> Union[EnsembleDataset, Dataset]: + """Instantiate `dataset_reference`. + + Args: + selection: The selected event id's. + + Returns: + A dataset, either an instance of `EnsembleDataset` or `Dataset`. + """ + if self._use_ensemble_dataset: + # Construct multiple datasets and pass to EnsembleDataset + # len(selection) == len(dataset_args['path']) + datasets = [] + for dataset_idx in range(len(selection)): + datasets.append( + self._create_single_dataset( + selection=selection[dataset_idx], # type: ignore + path=self._dataset_args["path"][dataset_idx], + ) + ) + + dataset = EnsembleDataset(datasets) + + else: + # Construct single dataset + dataset = self._create_single_dataset( + selection=selection, + path=self._dataset_args["path"], # type:ignore + ) + return dataset + + def _create_single_dataset( + self, + selection: Union[List[int], List[List[int]], List[float]], + path: str, + ) -> Dataset: + """Instantiate a single `Dataset`. + + Args: + selection: A selection for a single dataset. + path: Path to a single dataset + + Returns: + An instance of `Dataset`. 
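For reference, the train/validation splitting implemented in `_split_selection` boils down to a single `train_test_split` call on a flat list of event numbers. A self-contained sketch with the default `train_val_split=[0.9, 0.10]` and `split_seed=42` (the event numbers are stand-ins):

```python
from sklearn.model_selection import train_test_split

# Stand-in for a flat list of event numbers, i.e. what `_split_selection`
# operates on after flattening the user selection (or the shuffled indices
# inferred from the dataset when no selection is given).
event_nos = list(range(1000))

train_sel, val_sel = train_test_split(
    event_nos,
    train_size=0.9,   # train_val_split[0]
    test_size=0.1,    # train_val_split[1]
    random_state=42,  # split_seed
)
print(len(train_sel), len(val_sel))  # -> 900 100
```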
+        """
+        tmp_args = deepcopy(self._dataset_args)
+        tmp_args["path"] = path
+        tmp_args["selection"] = selection
+        return self._construct_dataset(tmp_args)
diff --git a/src/graphnet/data/extractors/__init__.py b/src/graphnet/data/extractors/__init__.py
index e1d4895bf..c6f4f325e 100644
--- a/src/graphnet/data/extractors/__init__.py
+++ b/src/graphnet/data/extractors/__init__.py
@@ -1,20 +1,2 @@
-"""Collection of I3Extractors, extracting pure-python data from I3Frames."""
-
-from .i3extractor import I3Extractor, I3ExtractorCollection
-from .i3featureextractor import (
-    I3FeatureExtractor,
-    I3FeatureExtractorIceCube86,
-    I3FeatureExtractorIceCubeDeepCore,
-    I3FeatureExtractorIceCubeUpgrade,
-    I3PulseNoiseTruthFlagIceCubeUpgrade,
-)
-from .i3truthextractor import I3TruthExtractor
-from .i3retroextractor import I3RetroExtractor
-from .i3splinempeextractor import I3SplineMPEICExtractor
-from .i3particleextractor import I3ParticleExtractor
-from .i3tumextractor import I3TUMExtractor
-from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor
-from .i3genericextractor import I3GenericExtractor
-from .i3pisaextractor import I3PISAExtractor
-from .i3ntmuonlabelsextractor import I3NTMuonLabelExtractor
-from .i3quesoextractor import I3QUESOExtractor
+"""Module containing data-specific extractor modules."""
+from .extractor import Extractor
diff --git a/src/graphnet/data/extractors/extractor.py b/src/graphnet/data/extractors/extractor.py
new file mode 100644
index 000000000..ce743f63d
--- /dev/null
+++ b/src/graphnet/data/extractors/extractor.py
@@ -0,0 +1,46 @@
+"""Base Extractor class(es)."""
+from typing import Any
+from abc import ABC, abstractmethod
+
+from graphnet.utilities.logging import Logger
+
+
+class Extractor(ABC, Logger):
+    """Base class for extracting information from data files.
+
+    All classes inheriting from `Extractor` should implement the `__call__`
+    method, and should return a pure-python dictionary of the form
+
+    output = {'var1': ..,
+             ... ,
+             'var_n': ..}
+
+    Variables can be scalar or array-like of shape [n, 1], where n denotes the
+    number of elements in the array, and 1 the number of columns.
+
+    An extractor is used in conjunction with a specific `FileReader`.
+    """
+
+    def __init__(self, extractor_name: str):
+        """Construct Extractor.
+
+        Args:
+            extractor_name: Name of the `Extractor` instance. Used to keep
+                track of the provenance of different data, and to name tables
+                to which this data is saved. E.g. "mc_truth".
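To illustrate the new, experiment-agnostic base class: a minimal `Extractor` subclass only needs to implement `__call__` and return a pure-python dictionary. The tabular input and column names below are hypothetical and not part of GraphNeT; this is a sketch, not a shipped extractor.

```python
import pandas as pd

from graphnet.data.extractors import Extractor


class MyTruthExtractor(Extractor):
    """Hypothetical extractor pulling truth variables from a tabular event."""

    def __init__(self, extractor_name: str = "mc_truth") -> None:
        # `extractor_name` becomes the name of the output table.
        super().__init__(extractor_name=extractor_name)

    def __call__(self, data: pd.Series) -> dict:
        # Must return a pure-python dictionary of scalars or array-likes.
        return {
            "energy": float(data["energy"]),
            "zenith": float(data["zenith"]),
            "azimuth": float(data["azimuth"]),
        }
```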
+ """ + # Member variable(s) + self._extractor_name: str = extractor_name + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + @abstractmethod + def __call__(self, data: Any) -> dict: + """Extract information from data.""" + pass + + @property + def name(self) -> str: + """Get the name of the `Extractor` instance.""" + return self._extractor_name diff --git a/src/graphnet/data/extractors/i3extractor.py b/src/graphnet/data/extractors/i3extractor.py deleted file mode 100644 index 90a982387..000000000 --- a/src/graphnet/data/extractors/i3extractor.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Base I3Extractor class(es).""" - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from graphnet.utilities.imports import has_icecube_package -from graphnet.utilities.logging import Logger - -if has_icecube_package() or TYPE_CHECKING: - from icecube import icetray, dataio # pyright: reportMissingImports=false - - -class I3Extractor(ABC, Logger): - """Base class for extracting information from physics I3-frames. - - All classes inheriting from `I3Extractor` should implement the `__call__` - method, and can be applied directly on `icetray.I3Frame` objects to return - extracted, pure-python data. - """ - - def __init__(self, name: str): - """Construct I3Extractor. - - Args: - name: Name of the `I3Extractor` instance. Used to keep track of the - provenance of different data, and to name tables to which this - data is saved. - """ - # Member variable(s) - self._i3_file: str = "" - self._gcd_file: str = "" - self._gcd_dict: Dict[int, Any] = {} - self._calibration: Optional["icetray.I3Frame.Calibration"] = None - self._name: str = name - - # Base class constructor - super().__init__(name=__name__, class_name=self.__class__.__name__) - - def set_files(self, i3_file: str, gcd_file: str) -> None: - """Store references to the I3- and GCD-files being processed.""" - # @TODO: Is it necessary to set the `i3_file`? It is only used in one - # place in `I3TruthExtractor`, and there only in a way that might - # be solved another way. - self._i3_file = i3_file - self._gcd_file = gcd_file - self._load_gcd_data() - - def _load_gcd_data(self) -> None: - """Load the geospatial information contained in the GCD-file.""" - # If no GCD file is provided, search the I3 file for frames containing - # geometry (G) and calibration (C) information. - gcd_file = dataio.I3File(self._gcd_file or self._i3_file) - - try: - g_frame = gcd_file.pop_frame(icetray.I3Frame.Geometry) - except RuntimeError: - self.error( - "No GCD file was provided and no G-frame was found. Exiting." - ) - raise - else: - self._gcd_dict = g_frame["I3Geometry"].omgeo - - try: - c_frame = gcd_file.pop_frame(icetray.I3Frame.Calibration) - except RuntimeError: - self.warning("No GCD file was provided and no C-frame was found.") - else: - self._calibration = c_frame["I3Calibration"] - - @abstractmethod - def __call__(self, frame: "icetray.I3Frame") -> dict: - """Extract information from frame.""" - pass - - @property - def name(self) -> str: - """Get the name of the `I3Extractor` instance.""" - return self._name - - -class I3ExtractorCollection(list): - """Class to manage multiple I3Extractors.""" - - def __init__(self, *extractors: I3Extractor): - """Construct I3ExtractorCollection. - - Args: - *extractors: List of `I3Extractor`s to be treated as a single - collection. 
- """ - # Check(s) - for extractor in extractors: - assert isinstance(extractor, I3Extractor) - - # Base class constructor - super().__init__(extractors) - - def set_files(self, i3_file: str, gcd_file: str) -> None: - """Store references to the I3- and GCD-files being processed.""" - for extractor in self: - extractor.set_files(i3_file, gcd_file) - - def __call__(self, frame: "icetray.I3Frame") -> List[dict]: - """Extract information from frame for each member `I3Extractor`.""" - return [extractor(frame) for extractor in self] diff --git a/src/graphnet/data/extractors/i3particleextractor.py b/src/graphnet/data/extractors/i3particleextractor.py deleted file mode 100644 index bd37424d2..000000000 --- a/src/graphnet/data/extractors/i3particleextractor.py +++ /dev/null @@ -1,43 +0,0 @@ -"""I3Extractor class(es) for extracting I3Particle properties.""" - -from typing import TYPE_CHECKING, Dict - -from graphnet.data.extractors.i3extractor import I3Extractor - -if TYPE_CHECKING: - from icecube import icetray # pyright: reportMissingImports=false - - -class I3ParticleExtractor(I3Extractor): - """Class for extracting I3Particle properties. - - Can be used to extract predictions from other algorithms for comparisons - with GraphNeT. - """ - - def __init__(self, name: str): - """Construct I3ParticleExtractor.""" - # Base class constructor - super().__init__(name) - - def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]: - """Extract I3Particle properties from I3Particle in frame.""" - output = {} - if self._name in frame: - output.update( - { - "zenith_" + self._name: frame[self._name].dir.zenith, - "azimuth_" + self._name: frame[self._name].dir.azimuth, - "dir_x_" + self._name: frame[self._name].dir.x, - "dir_y_" + self._name: frame[self._name].dir.y, - "dir_z_" + self._name: frame[self._name].dir.z, - "pos_x_" + self._name: frame[self._name].pos.x, - "pos_y_" + self._name: frame[self._name].pos.y, - "pos_z_" + self._name: frame[self._name].pos.z, - "time_" + self._name: frame[self._name].time, - "speed_" + self._name: frame[self._name].speed, - "energy_" + self._name: frame[self._name].energy, - } - ) - - return output diff --git a/src/graphnet/data/extractors/icecube/__init__.py b/src/graphnet/data/extractors/icecube/__init__.py new file mode 100644 index 000000000..11befe581 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/__init__.py @@ -0,0 +1,20 @@ +"""Collection of I3Extractors, extracting pure-python data from I3Frames.""" + +from .i3extractor import I3Extractor +from .i3featureextractor import ( + I3FeatureExtractor, + I3FeatureExtractorIceCube86, + I3FeatureExtractorIceCubeDeepCore, + I3FeatureExtractorIceCubeUpgrade, + I3PulseNoiseTruthFlagIceCubeUpgrade, +) +from .i3truthextractor import I3TruthExtractor +from .i3retroextractor import I3RetroExtractor +from .i3splinempeextractor import I3SplineMPEICExtractor +from .i3particleextractor import I3ParticleExtractor +from .i3tumextractor import I3TUMExtractor +from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor +from .i3genericextractor import I3GenericExtractor +from .i3pisaextractor import I3PISAExtractor +from .i3ntmuonlabelsextractor import I3NTMuonLabelExtractor +from .i3quesoextractor import I3QUESOExtractor diff --git a/src/graphnet/data/extractors/icecube/i3extractor.py b/src/graphnet/data/extractors/icecube/i3extractor.py new file mode 100644 index 000000000..3f2fc92d2 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/i3extractor.py @@ -0,0 +1,92 @@ +"""Base I3Extractor class(es).""" + 
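In practice, the reorganisation above only changes import paths for IceCube users; the extractor classes themselves keep their names. For example (old path shown for comparison):

```python
# GraphNeT 1.x:
# from graphnet.data.extractors import I3FeatureExtractorIceCube86, I3TruthExtractor

# After this change, IceCube-specific extractors live in a dedicated subpackage:
from graphnet.data.extractors.icecube import (
    I3FeatureExtractorIceCube86,
    I3TruthExtractor,
)
```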
+from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Optional + +from graphnet.utilities.imports import has_icecube_package +from graphnet.data.extractors import Extractor + +if has_icecube_package() or TYPE_CHECKING: + from icecube import icetray, dataio # pyright: reportMissingImports=false + + +class I3Extractor(Extractor): + """Base class for extracting information from physics I3-frames. + + Contains functionality required to extract data from i3 files, used by + the IceCube Neutrino Observatory. + + All classes inheriting from `I3Extractor` should implement the `__call__` + method. + """ + + def __init__(self, extractor_name: str): + """Construct I3Extractor. + + Args: + extractor_name: Name of the `I3Extractor` instance. Used to keep track of the + provenance of different data, and to name tables to which this + data is saved. + """ + # Member variable(s) + self._i3_file: str = "" + self._gcd_file: str = "" + self._gcd_dict: Dict[int, Any] = {} + self._calibration: Optional["icetray.I3Frame.Calibration"] = None + + # Base class constructor + super().__init__(extractor_name=extractor_name) + + def set_gcd(self, i3_file: str, gcd_file: Optional[str] = None) -> None: + """Extract GFrame and CFrame from i3/gcd-file pair. + + Information from these frames will be set as member variables of + `I3Extractor.` + + Args: + i3_file: Path to i3 file that is being converted. + gcd_file: Path to GCD file. Defaults to None. If no GCD file is + given, the method will attempt to find C and G frames in + the i3 file instead. If either one of those are not + present, `RuntimeErrors` will be raised. + """ + if gcd_file is None: + # If no GCD file is provided, search the I3 file for frames + # containing geometry (GFrame) and calibration (CFrame) + gcd = dataio.I3File(i3_file) + else: + # Ideally ends here + gcd = dataio.I3File(gcd_file) + + # Get GFrame + try: + g_frame = gcd.pop_frame(icetray.I3Frame.Geometry) + # If the line above fails, it means that no gcd file was given + # and that the i3 file does not have a G-Frame in it. + except RuntimeError as e: + self.error( + "No GCD file was provided " + f"and no G-frame was found in {i3_file.split('/')[-1]}." + ) + raise e + + # Get CFrame + try: + c_frame = gcd.pop_frame(icetray.I3Frame.Calibration) + # If the line above fails, it means that no gcd file was given + # and that the i3 file does not have a C-Frame in it. + except RuntimeError as e: + self.warning( + "No GCD file was provided and no C-frame " + f"was found in {i3_file.split('/')[-1]}." 
+ ) + raise e + + # Save information as member variables of I3Extractor + self._gcd_dict = g_frame["I3Geometry"].omgeo + self._calibration = c_frame["I3Calibration"] + + @abstractmethod + def __call__(self, frame: "icetray.I3Frame") -> dict: + """Extract information from frame.""" + pass diff --git a/src/graphnet/data/extractors/i3featureextractor.py b/src/graphnet/data/extractors/icecube/i3featureextractor.py similarity index 97% rename from src/graphnet/data/extractors/i3featureextractor.py rename to src/graphnet/data/extractors/icecube/i3featureextractor.py index f1f578453..258bb368c 100644 --- a/src/graphnet/data/extractors/i3featureextractor.py +++ b/src/graphnet/data/extractors/icecube/i3featureextractor.py @@ -1,17 +1,14 @@ """I3Extractor class(es) for extracting specific, reconstructed features.""" from typing import TYPE_CHECKING, Any, Dict, List -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from .i3extractor import I3Extractor +from graphnet.data.extractors.icecube.utilities.frames import ( get_om_keys_and_pulseseries, ) from graphnet.utilities.imports import has_icecube_package if has_icecube_package() or TYPE_CHECKING: - from icecube import ( - icetray, - dataclasses, - ) # pyright: reportMissingImports=false + from icecube import icetray # pyright: reportMissingImports=false class I3FeatureExtractor(I3Extractor): diff --git a/src/graphnet/data/extractors/i3genericextractor.py b/src/graphnet/data/extractors/icecube/i3genericextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3genericextractor.py rename to src/graphnet/data/extractors/icecube/i3genericextractor.py index 6a86303e7..c79b7329b 100644 --- a/src/graphnet/data/extractors/i3genericextractor.py +++ b/src/graphnet/data/extractors/icecube/i3genericextractor.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.types import ( +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.types import ( cast_object_to_pure_python, cast_pulse_series_to_pure_python, ) -from graphnet.data.extractors.utilities.collections import ( +from graphnet.data.extractors.icecube.utilities.collections import ( transpose_list_of_dicts, serialise, flatten_nested_dictionary, @@ -45,6 +45,7 @@ def __init__( self, keys: Optional[Union[str, List[str]]] = None, exclude_keys: Optional[Union[str, List[str]]] = None, + extractor_name: str = GENERIC_EXTRACTOR_NAME, ): """Construct I3GenericExtractor. @@ -72,7 +73,7 @@ def __init__( self._exclude_keys: Optional[List[str]] = exclude_keys # Base class constructor - super().__init__(GENERIC_EXTRACTOR_NAME) + super().__init__(extractor_name) def _get_keys(self, frame: "icetray.I3Frame") -> List[str]: """Get the list of keys to be queried from `frame`. 
@@ -170,6 +171,12 @@ def __call__(self, frame: "icetray.I3Frame") -> Dict[str, Any]: # Flatten all other objects else: results[key] = self._flatten_result(result) + if ( + isinstance(results[key], dict) + and "value" in results[key] + and len(results[key]) == 1 + ): + results[key] = results[key]["value"] # Serialise list of iterables to JSON results = {key: serialise(value) for key, value in results.items()} diff --git a/src/graphnet/data/extractors/i3hybridrecoextractor.py b/src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py similarity index 96% rename from src/graphnet/data/extractors/i3hybridrecoextractor.py rename to src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py index 74f445120..90525bcab 100644 --- a/src/graphnet/data/extractors/i3hybridrecoextractor.py +++ b/src/graphnet/data/extractors/icecube/i3hybridrecoextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3ntmuonlabelsextractor.py b/src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py similarity index 96% rename from src/graphnet/data/extractors/i3ntmuonlabelsextractor.py rename to src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py index 1ca3e8bcb..039b13cfe 100644 --- a/src/graphnet/data/extractors/i3ntmuonlabelsextractor.py +++ b/src/graphnet/data/extractors/icecube/i3ntmuonlabelsextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/icecube/i3particleextractor.py b/src/graphnet/data/extractors/icecube/i3particleextractor.py new file mode 100644 index 000000000..a50c11d21 --- /dev/null +++ b/src/graphnet/data/extractors/icecube/i3particleextractor.py @@ -0,0 +1,44 @@ +"""I3Extractor class(es) for extracting I3Particle properties.""" + +from typing import TYPE_CHECKING, Dict + +from graphnet.data.extractors.icecube import I3Extractor + +if TYPE_CHECKING: + from icecube import icetray # pyright: reportMissingImports=false + + +class I3ParticleExtractor(I3Extractor): + """Class for extracting I3Particle properties. + + Can be used to extract predictions from other algorithms for comparisons + with GraphNeT. 
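As the `__call__` implementation below shows, the `extractor_name` now doubles as the frame key that is read and as the suffix on the output columns. A small usage sketch; the reconstruction key is an assumption and must correspond to a key actually present in your I3 files:

```python
from graphnet.data.extractors.icecube import I3ParticleExtractor

# Reads the I3Particle stored under this (example) frame key and produces
# columns such as "zenith_OnlineL2_SplineMPE", "energy_OnlineL2_SplineMPE", ...
extractor = I3ParticleExtractor(extractor_name="OnlineL2_SplineMPE")
```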
+ """ + + def __init__(self, extractor_name: str): + """Construct I3ParticleExtractor.""" + # Base class constructor + super().__init__(extractor_name=extractor_name) + + def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]: + """Extract I3Particle properties from I3Particle in frame.""" + output = {} + name = self._extractor_name + if name in frame: + output.update( + { + "zenith_" + name: frame[name].dir.zenith, + "azimuth_" + name: frame[name].dir.azimuth, + "dir_x_" + name: frame[name].dir.x, + "dir_y_" + name: frame[name].dir.y, + "dir_z_" + name: frame[name].dir.z, + "pos_x_" + name: frame[name].pos.x, + "pos_y_" + name: frame[name].pos.y, + "pos_z_" + name: frame[name].pos.z, + "time_" + name: frame[name].time, + "speed_" + name: frame[name].speed, + "energy_" + name: frame[name].energy, + } + ) + + return output diff --git a/src/graphnet/data/extractors/i3pisaextractor.py b/src/graphnet/data/extractors/icecube/i3pisaextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3pisaextractor.py rename to src/graphnet/data/extractors/icecube/i3pisaextractor.py index fd5a09583..f14a8046a 100644 --- a/src/graphnet/data/extractors/i3pisaextractor.py +++ b/src/graphnet/data/extractors/icecube/i3pisaextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3quesoextractor.py b/src/graphnet/data/extractors/icecube/i3quesoextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3quesoextractor.py rename to src/graphnet/data/extractors/icecube/i3quesoextractor.py index b72b20046..e29c72a41 100644 --- a/src/graphnet/data/extractors/i3quesoextractor.py +++ b/src/graphnet/data/extractors/icecube/i3quesoextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube.i3extractor import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3retroextractor.py b/src/graphnet/data/extractors/icecube/i3retroextractor.py similarity index 97% rename from src/graphnet/data/extractors/i3retroextractor.py rename to src/graphnet/data/extractors/icecube/i3retroextractor.py index cd55d01f4..aaeb773b4 100644 --- a/src/graphnet/data/extractors/i3retroextractor.py +++ b/src/graphnet/data/extractors/icecube/i3retroextractor.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any, Dict -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.frames import ( frame_is_montecarlo, frame_is_noise, ) diff --git a/src/graphnet/data/extractors/i3splinempeextractor.py b/src/graphnet/data/extractors/icecube/i3splinempeextractor.py similarity index 93% rename from src/graphnet/data/extractors/i3splinempeextractor.py rename to src/graphnet/data/extractors/icecube/i3splinempeextractor.py index e47b2e71d..1439ada51 100644 --- a/src/graphnet/data/extractors/i3splinempeextractor.py +++ b/src/graphnet/data/extractors/icecube/i3splinempeextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor 
+from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/i3truthextractor.py b/src/graphnet/data/extractors/icecube/i3truthextractor.py similarity index 99% rename from src/graphnet/data/extractors/i3truthextractor.py rename to src/graphnet/data/extractors/icecube/i3truthextractor.py index bcfe694c3..b715e57ab 100644 --- a/src/graphnet/data/extractors/i3truthextractor.py +++ b/src/graphnet/data/extractors/icecube/i3truthextractor.py @@ -4,8 +4,8 @@ import matplotlib.path as mpath from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from graphnet.data.extractors.i3extractor import I3Extractor -from graphnet.data.extractors.utilities.frames import ( +from .i3extractor import I3Extractor +from .utilities.frames import ( frame_is_montecarlo, frame_is_noise, ) diff --git a/src/graphnet/data/extractors/i3tumextractor.py b/src/graphnet/data/extractors/icecube/i3tumextractor.py similarity index 94% rename from src/graphnet/data/extractors/i3tumextractor.py rename to src/graphnet/data/extractors/icecube/i3tumextractor.py index 38cbca146..685b0a78e 100644 --- a/src/graphnet/data/extractors/i3tumextractor.py +++ b/src/graphnet/data/extractors/icecube/i3tumextractor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict -from graphnet.data.extractors.i3extractor import I3Extractor +from graphnet.data.extractors.icecube import I3Extractor if TYPE_CHECKING: from icecube import icetray # pyright: reportMissingImports=false diff --git a/src/graphnet/data/extractors/utilities/__init__.py b/src/graphnet/data/extractors/icecube/utilities/__init__.py similarity index 100% rename from src/graphnet/data/extractors/utilities/__init__.py rename to src/graphnet/data/extractors/icecube/utilities/__init__.py diff --git a/src/graphnet/data/extractors/utilities/collections.py b/src/graphnet/data/extractors/icecube/utilities/collections.py similarity index 100% rename from src/graphnet/data/extractors/utilities/collections.py rename to src/graphnet/data/extractors/icecube/utilities/collections.py diff --git a/src/graphnet/data/extractors/utilities/frames.py b/src/graphnet/data/extractors/icecube/utilities/frames.py similarity index 100% rename from src/graphnet/data/extractors/utilities/frames.py rename to src/graphnet/data/extractors/icecube/utilities/frames.py diff --git a/src/graphnet/data/filters.py b/src/graphnet/data/extractors/icecube/utilities/i3_filters.py similarity index 100% rename from src/graphnet/data/filters.py rename to src/graphnet/data/extractors/icecube/utilities/i3_filters.py diff --git a/src/graphnet/data/extractors/utilities/types.py b/src/graphnet/data/extractors/icecube/utilities/types.py similarity index 98% rename from src/graphnet/data/extractors/utilities/types.py rename to src/graphnet/data/extractors/icecube/utilities/types.py index cf58e8357..32ecae0ff 100644 --- a/src/graphnet/data/extractors/utilities/types.py +++ b/src/graphnet/data/extractors/icecube/utilities/types.py @@ -4,11 +4,11 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from graphnet.data.extractors.utilities.collections import ( +from graphnet.data.extractors.icecube.utilities.collections import ( transpose_list_of_dicts, flatten_nested_dictionary, ) -from graphnet.data.extractors.utilities.frames import ( +from graphnet.data.extractors.icecube.utilities.frames import ( get_om_keys_and_pulseseries, ) from graphnet.utilities.imports import 
has_icecube_package diff --git a/src/graphnet/data/parquet/__init__.py b/src/graphnet/data/parquet/__init__.py index 616d89c16..2c41ca75d 100644 --- a/src/graphnet/data/parquet/__init__.py +++ b/src/graphnet/data/parquet/__init__.py @@ -1,2 +1,2 @@ -"""Parquet-specific implementation of data classes.""" -from .parquet_dataconverter import ParquetDataConverter +"""Module for deprecated parquet methods.""" +from .deprecated_methods import ParquetDataConverter diff --git a/src/graphnet/data/parquet/deprecated_methods.py b/src/graphnet/data/parquet/deprecated_methods.py new file mode 100644 index 000000000..423e1aa00 --- /dev/null +++ b/src/graphnet/data/parquet/deprecated_methods.py @@ -0,0 +1,60 @@ +"""Module containing deprecated data conversion code. + +This code will be removed in GraphNeT 2.0. +""" +from typing import List, Union, Type + +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, +) +from graphnet.data import I3ToParquetConverter + + +class ParquetDataConverter(I3ToParquetConverter): + """Method for converting i3 files to parquet files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + extractors=extractors, + num_workers=workers, + index_column=index_column, + i3_filters=i3_filters, + outdir=outdir, + gcd_rescue=gcd_rescue, + ) + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToParquetConverter instead." + ) diff --git a/src/graphnet/data/parquet/parquet_dataconverter.py b/src/graphnet/data/parquet/parquet_dataconverter.py deleted file mode 100644 index 68531c8e2..000000000 --- a/src/graphnet/data/parquet/parquet_dataconverter.py +++ /dev/null @@ -1,52 +0,0 @@ -"""DataConverter for the Parquet backend.""" - -from collections import OrderedDict -import os -from typing import List, Optional - -import awkward - -from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] - - -class ParquetDataConverter(DataConverter): - """Class for converting I3-files to Parquet format.""" - - # Class variables - file_suffix: str = "parquet" - - # Abstract method implementation(s) - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Save data to parquet file.""" - # Check(s) - if os.path.exists(output_file): - self.warning( - f"Output file {output_file} already exists. Overwriting." 
- ) - - self.debug(f"Saving to {output_file}") - self.debug( - f"- Data has {len(data)} events and {len(data[0])} tables for each" - ) - - awkward.to_parquet(awkward.from_iter(data), output_file) - - self.debug("- Done saving") - self._output_files.append(output_file) - - def merge_files( - self, output_file: str, input_files: Optional[List[str]] = None - ) -> None: - """Parquet-specific method for merging output files. - - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files to be merged, according to the - specific implementation. Default to None, meaning that all - files output by the current instance are merged. - - Raises: - NotImplementedError: If the method has not been implemented for the - Parquet backend. - """ - raise NotImplementedError() diff --git a/src/graphnet/data/pipeline.py b/src/graphnet/data/pipeline.py index d97415bb0..9973c763f 100644 --- a/src/graphnet/data/pipeline.py +++ b/src/graphnet/data/pipeline.py @@ -13,7 +13,9 @@ import torch from torch.utils.data import DataLoader -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.training.utils import get_predictions, make_dataloader from graphnet.models.graphs import GraphDefinition diff --git a/src/graphnet/data/pre_configured/__init__.py b/src/graphnet/data/pre_configured/__init__.py new file mode 100644 index 000000000..f56f0de18 --- /dev/null +++ b/src/graphnet/data/pre_configured/__init__.py @@ -0,0 +1,2 @@ +"""Module for pre-configured converter modules.""" +from .dataconverters import I3ToParquetConverter, I3ToSQLiteConverter diff --git a/src/graphnet/data/pre_configured/dataconverters.py b/src/graphnet/data/pre_configured/dataconverters.py new file mode 100644 index 000000000..6db89c46e --- /dev/null +++ b/src/graphnet/data/pre_configured/dataconverters.py @@ -0,0 +1,99 @@ +"""Pre-configured combinations of writers and readers.""" + +from typing import List, Union, Type + +from graphnet.data import DataConverter +from graphnet.data.readers import I3Reader +from graphnet.data.writers import ParquetWriter, SQLiteWriter +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import I3Filter + + +class I3ToParquetConverter(DataConverter): + """Preconfigured DataConverter for converting i3 files to parquet files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. 
+ Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), + save_method=ParquetWriter(), + extractors=extractors, + num_workers=num_workers, + index_column=index_column, + outdir=outdir, + ) + + +class I3ToSQLiteConverter(DataConverter): + """Preconfigured DataConverter for converting i3 files to SQLite files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + num_workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to SQLite. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + num_workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + file_reader=I3Reader(gcd_rescue=gcd_rescue, i3_filters=i3_filters), + save_method=SQLiteWriter(), + extractors=extractors, + num_workers=num_workers, + index_column=index_column, + outdir=outdir, + ) diff --git a/src/graphnet/data/readers/__init__.py b/src/graphnet/data/readers/__init__.py new file mode 100644 index 000000000..0755bd35a --- /dev/null +++ b/src/graphnet/data/readers/__init__.py @@ -0,0 +1,3 @@ +"""Modules for reading experiment-specific data and applying Extractors.""" +from .graphnet_file_reader import GraphNeTFileReader +from .i3reader import I3Reader diff --git a/src/graphnet/data/readers/graphnet_file_reader.py b/src/graphnet/data/readers/graphnet_file_reader.py new file mode 100644 index 000000000..c590c6424 --- /dev/null +++ b/src/graphnet/data/readers/graphnet_file_reader.py @@ -0,0 +1,142 @@ +"""Module containing different FileReader classes in GraphNeT. + +These methods are used to open and apply `Extractors` to experiment-specific +file formats. +""" + +from typing import List, Union, OrderedDict, Any +from abc import abstractmethod, ABC +import glob +import os + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger +from graphnet.data.dataclasses import I3FileSet +from graphnet.data.extractors.extractor import Extractor +from graphnet.data.extractors.icecube import I3Extractor + + +class GraphNeTFileReader(Logger, ABC): + """A generic base class for FileReaders in GraphNeT. + + Classes inheriting from `GraphNeTFileReader` must implement a + `__call__` method that opens a file, applies `Extractor`(s) and returns + a list of ordered dictionaries. + + In addition, Classes inheriting from `GraphNeTFileReader` must set + class properties `accepted_file_extensions` and `accepted_extractors`. 
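A sketch of how the pre-configured converters are intended to be used. The extractor arguments and the call pattern `converter(input_dir)` follow the existing `DataConverter` interface but are not shown in this diff, so treat them as assumptions rather than a definitive recipe:

```python
from graphnet.data import I3ToSQLiteConverter
from graphnet.data.extractors.icecube import (
    I3FeatureExtractorIceCube86,
    I3TruthExtractor,
)

converter = I3ToSQLiteConverter(
    gcd_rescue="/path/to/fallback_gcd.i3.gz",
    extractors=[
        I3FeatureExtractorIceCube86("SRTInIcePulses"),  # pulse-map name is an example
        I3TruthExtractor(),
    ],
    outdir="/path/to/output",
    num_workers=4,
)

converter("/path/to/i3_files")  # convert all I3/GCD pairs found under this directory
converter.merge_files()         # merge intermediate files into "<outdir>/merged"
```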
+ """ + + _accepted_file_extensions: List[str] = [] + _accepted_extractors: List[Any] = [] + + @abstractmethod + def __call__(self, file_path: Union[str, I3FileSet]) -> List[OrderedDict]: + """Open and apply extractors to a single file. + + The `output` must be a list of dictionaries, where the number of events + in the file `n_events` satisfies `len(output) = n_events`. I.e each + element in the list is a dictionary, and each field in the dictionary + is the output of a single extractor. + """ + + @property + def accepted_file_extensions(self) -> List[str]: + """Return list of accepted file extensions.""" + return self._accepted_file_extensions + + @property + def accepted_extractors(self) -> List[Extractor]: + """Return list of compatible `Extractor`(s).""" + return self._accepted_extractors + + @property + def extracor_names(self) -> List[str]: + """Return list of table names produced by extractors.""" + return [extractor.name for extractor in self._extractors] + + def find_files( + self, path: Union[str, List[str]] + ) -> Union[List[str], List[I3FileSet]]: + """Search directory for input files recursively. + + This method may be overwritten by custom implementations. + + Args: + path: path to directory. + + Returns: + List of files matching accepted file extensions. + """ + if isinstance(path, str): + path = [path] + files = [] + for dir in path: + for accepted_file_extension in self.accepted_file_extensions: + files.extend(glob.glob(dir + f"/*{accepted_file_extension}")) + + # Check that files are OK. + self.validate_files(files) + return files + + @final + def set_extractors( + self, extractors: Union[List[Extractor], List[I3Extractor]] + ) -> None: + """Set `Extractor`(s) as member variable. + + Args: + extractors: A list of `Extractor`(s) to set as member variable. + """ + if not isinstance(extractors, list): + extractors = [extractors] + self._validate_extractors(extractors) + self._extractors = extractors + + @final + def _validate_extractors( + self, extractors: Union[List[Extractor], List[I3Extractor]] + ) -> None: + for extractor in extractors: + try: + assert isinstance(extractor, tuple(self.accepted_extractors)) # type: ignore + except AssertionError as e: + self.error( + f"{extractor.__class__.__name__}" + f" is not supported by {self.__class__.__name__}" + ) + raise e + + @final + def validate_files( + self, input_files: Union[List[str], List[I3FileSet]] + ) -> None: + """Check that the input files are accepted by the reader. + + Args: + input_files: Path(s) to input file(s). + """ + for input_file in input_files: + # Handle filepath vs. FileSet cases + if isinstance(input_file, I3FileSet): + self._validate_file(input_file.i3_file) + self._validate_file(input_file.gcd_file) + else: + self._validate_file(input_file) + + @final + def _validate_file(self, file: str) -> None: + """Validate a single file path. + + Args: + file: path to file. + + Returns: + None + """ + try: + assert file.lower().endswith(tuple(self.accepted_file_extensions)) + except AssertionError: + self.error( + f'{self.__class__.__name__} accepts {self.accepted_file_extensions} but {file.split("/")[-1]} has extension {os.path.splitext(file)[1]}.' 
+ ) diff --git a/src/graphnet/data/readers/i3reader.py b/src/graphnet/data/readers/i3reader.py new file mode 100644 index 000000000..ed5fd7c1f --- /dev/null +++ b/src/graphnet/data/readers/i3reader.py @@ -0,0 +1,137 @@ +"""Module containing different I3Reader.""" + +from typing import List, Union, OrderedDict, Type + +from graphnet.utilities.imports import has_icecube_package +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.dataclasses import I3FileSet +from graphnet.utilities.filesys import find_i3_files +from .graphnet_file_reader import GraphNeTFileReader + + +if has_icecube_package(): + from icecube import icetray, dataio # pyright: reportMissingImports=false + + +class I3Reader(GraphNeTFileReader): + """A class for reading .i3 files from the IceCube Neutrino Observatory. + + Note that this class relies on IceCube-specific software, and therefore + must be run in a software environment that contains IceTray. + """ + + def __init__( + self, + gcd_rescue: str, + i3_filters: Union[I3Filter, List[I3Filter]] = None, + icetray_verbose: int = 0, + ): + """Initialize `I3Reader`. + + Args: + gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + """ + # Set verbosity + if icetray_verbose == 0: + icetray.I3Logger.global_logger = icetray.I3NullLogger() + + if i3_filters is None: + i3_filters = [NullSplitI3Filter()] + # Set Member Variables + self._accepted_file_extensions = [".bz2", ".zst", ".gz"] + self._accepted_extractors = [I3Extractor] + self._gcd_rescue = gcd_rescue + self._i3filters = ( + i3_filters if isinstance(i3_filters, list) else [i3_filters] + ) + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + def __call__(self, file_path: I3FileSet) -> List[OrderedDict]: # type: ignore + """Extract data from single I3 file. + + Args: + fileset: Path to I3 file and corresponding GCD file. + + Returns: + Extracted data. + """ + # Set I3-GCD file pair in extractor + for extractor in self._extractors: + assert isinstance(extractor, I3Extractor) + extractor.set_gcd( + i3_file=file_path.i3_file, gcd_file=file_path.gcd_file + ) + + # Open I3 file + i3_file_io = dataio.I3File(file_path.i3_file, "r") + data = list() + while i3_file_io.more(): + try: + frame = i3_file_io.pop_physics() + except Exception as e: + if "I3" in str(e): + continue + # check if frame should be skipped + if self._skip_frame(frame): + continue + + # Try to extract data from I3Frame + results = [extractor(frame) for extractor in self._extractors] + + data_dict = OrderedDict(zip(self.extracor_names, results)) + + data.append(data_dict) + return data + + def find_files(self, path: Union[str, List[str]]) -> List[I3FileSet]: + """Recursively search directory for I3 and GCD file pairs. + + Args: + path: directory to search recursively. 
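Readers for other experiments follow the same pattern as `I3Reader`: declare the accepted file extensions and extractor types, and implement `__call__` so that it returns one `OrderedDict` per event, keyed by extractor name. A hypothetical sketch (the CSV file format, the per-row event layout, and the generic `Extractor` usage are assumptions):

```python
from typing import List
from collections import OrderedDict

import pandas as pd

from graphnet.data.readers import GraphNeTFileReader
from graphnet.data.extractors import Extractor


class CSVReader(GraphNeTFileReader):
    """Hypothetical reader for per-event CSV files (illustration only)."""

    _accepted_file_extensions = [".csv"]
    _accepted_extractors = [Extractor]

    def __init__(self) -> None:
        super().__init__(name=__name__, class_name=self.__class__.__name__)

    def __call__(self, file_path: str) -> List[OrderedDict]:
        table = pd.read_csv(file_path)
        # One OrderedDict per event, keyed by the extractor names,
        # mirroring what `I3Reader.__call__` returns.
        events = []
        for _, row in table.iterrows():
            results = [extractor(row) for extractor in self._extractors]
            events.append(OrderedDict(zip(self.extracor_names, results)))
        return events
```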
+ + Returns: + List I3 and GCD file pairs as I3FileSets + """ + # Find all I3 and GCD files in the specified directories. + i3_files, gcd_files = find_i3_files( + path, + self._gcd_rescue, + ) + + # Pack as I3FileSets + filesets = [ + I3FileSet(i3_file, gcd_file) + for i3_file, gcd_file in zip(i3_files, gcd_files) + ] + return filesets + + def _skip_frame(self, frame: "icetray.I3Frame") -> bool: + """Check the user defined filters. + + Returns: + bool: True if frame should be skipped, False otherwise. + """ + if self._i3filters is None: + return False # No filters defined, so we keep the frame + + for filter in self._i3filters: + if not filter(frame): + return True # keep_frame call false, skip the frame. + return False # All filter keep_frame calls true, keep the frame. diff --git a/src/graphnet/data/sqlite/__init__.py b/src/graphnet/data/sqlite/__init__.py index e4ac554a7..436a86f2d 100644 --- a/src/graphnet/data/sqlite/__init__.py +++ b/src/graphnet/data/sqlite/__init__.py @@ -1,4 +1,2 @@ -"""SQLite-specific implementation of data classes.""" -from .sqlite_dataconverter import SQLiteDataConverter -from .sqlite_utilities import create_table_and_save_to_sql -from .sqlite_utilities import run_sql_code, save_to_sql +"""Module for deprecated methods using sqlite.""" +from .deprecated_methods import SQLiteDataConverter diff --git a/src/graphnet/data/sqlite/deprecated_methods.py b/src/graphnet/data/sqlite/deprecated_methods.py new file mode 100644 index 000000000..30b563c59 --- /dev/null +++ b/src/graphnet/data/sqlite/deprecated_methods.py @@ -0,0 +1,62 @@ +"""Module containing deprecated data conversion code. + +This code will be removed in GraphNeT 2.0. +""" + +from typing import List, Union, Type + +from graphnet.data.extractors.icecube import I3Extractor +from graphnet.data.extractors.icecube.utilities.i3_filters import ( + I3Filter, + NullSplitI3Filter, +) +from graphnet.data import I3ToSQLiteConverter + + +class SQLiteDataConverter(I3ToSQLiteConverter): + """Method for converting i3 files to SQLite files.""" + + def __init__( + self, + gcd_rescue: str, + extractors: List[I3Extractor], + outdir: str, + index_column: str = "event_no", + workers: int = 1, + i3_filters: Union[I3Filter, List[I3Filter]] = None, # type: ignore + ): + """Convert I3 files to Parquet. + + Args: + gcd_rescue: gcd_rescue: Path to a GCD file that will be used if no GCD file is + found in subfolder. `I3Reader` will recursively search + the input directory for I3-GCD file pairs. By IceCube + convention, a folder containing i3 files will have an + accompanying GCD file. However, in some cases, this + convention is broken. In cases where a folder contains + i3 files but no GCD file, the `gcd_rescue` is used + instead. + extractors: The `Extractor`(s) that will be applied to the input + files. + outdir: The directory to save the files in. + icetray_verbose: Set the level of verbosity of icetray. + Defaults to 0. + index_column: Name of the event id column added to the events. + Defaults to "event_no". + workers: The number of CPUs used for parallel processing. + Defaults to 1 (no multiprocessing). + i3_filters: Instances of `I3Filter` to filter PFrames. Defaults to + `NullSplitI3Filter`. + """ + super().__init__( + extractors=extractors, + num_workers=workers, + index_column=index_column, + i3_filters=i3_filters, + outdir=outdir, + gcd_rescue=gcd_rescue, + ) + self.warning( + f"{self.__class__.__name__} will be deprecated in " + "GraphNeT 2.0. Please use I3ToSQLiteConverter instead." 
+ ) diff --git a/src/graphnet/data/sqlite/sqlite_dataconverter.py b/src/graphnet/data/sqlite/sqlite_dataconverter.py deleted file mode 100644 index 1750b7a33..000000000 --- a/src/graphnet/data/sqlite/sqlite_dataconverter.py +++ /dev/null @@ -1,349 +0,0 @@ -"""DataConverter for the SQLite backend.""" - -from collections import OrderedDict -import os -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -import sqlalchemy -import sqlite3 -from tqdm import tqdm - -from graphnet.data.dataconverter import DataConverter # type: ignore[attr-defined] -from graphnet.data.sqlite.sqlite_utilities import ( - create_table, - create_table_and_save_to_sql, -) - - -class SQLiteDataConverter(DataConverter): - """Class for converting I3-file(s) to SQLite format.""" - - # Class variables - file_suffix = "db" - - # Abstract method implementation(s) - def save_data(self, data: List[OrderedDict], output_file: str) -> None: - """Save data to SQLite database.""" - # Check(s) - if os.path.exists(output_file): - self.warning( - f"Output file {output_file} already exists. Appending." - ) - - # Concatenate data - if len(data) == 0: - self.warning( - "No data was extracted from the processed I3 file(s). " - f"No data saved to {output_file}" - ) - return - - saved_any = False - dataframe_list: OrderedDict = OrderedDict( - [(key, []) for key in data[0]] - ) - for data_dict in data: - for key, data_values in data_dict.items(): - df = construct_dataframe(data_values) - - if self.any_pulsemap_is_non_empty(data_dict) and len(df) > 0: - # only include data_dict in temp. databases if at least one pulsemap is non-empty, - # and the current extractor (df) is also non-empty (also since truth is always non-empty) - dataframe_list[key].append(df) - - dataframe = OrderedDict( - [ - ( - key, - pd.concat(dfs, ignore_index=True, sort=True) - if dfs - else pd.DataFrame(), - ) - for key, dfs in dataframe_list.items() - ] - ) - # Can delete dataframe_list here to free up memory. - - # Save each dataframe to SQLite database - self.debug(f"Saving to {output_file}") - for table, df in dataframe.items(): - if len(df) > 0: - create_table_and_save_to_sql( - df, - table, - output_file, - default_type="FLOAT", - integer_primary_key=not ( - is_pulse_map(table) or is_mc_tree(table) - ), - ) - saved_any = True - - if saved_any: - self.debug("- Done saving") - else: - self.warning(f"No data saved to {output_file}") - - def merge_files( - self, - output_file: str, - input_files: Optional[List[str]] = None, - max_table_size: Optional[int] = None, - ) -> None: - """SQLite-specific method for merging output files/databases. - - Args: - output_file: Name of the output file containing the merged results. - input_files: Intermediate files/databases to be merged, according - to the specific implementation. Default to None, meaning that - all files/databases output by the current instance are merged. - max_table_size: The maximum number of rows in any given table. - If any one table exceed this limit, a new database will be - created. - """ - if max_table_size: - self.warning( - f"Merging got max_table_size of {max_table_size}. Will attempt to create databases with a maximum row count of this size." - ) - self.max_table_size = max_table_size - self._partition_count = 1 - - if input_files is None: - self.info("Merging files output by current instance.") - self._input_files = self._output_files - else: - self._input_files = input_files - - if not output_file.endswith("." 
+ self.file_suffix): - output_file = ".".join([output_file, self.file_suffix]) - - if os.path.exists(output_file): - self.warning( - f"Target path for merged database, {output_file}, already exists." - ) - - if len(self._input_files) > 0: - self.info(f"Merging {len(self._input_files)} database files") - # Create one empty database table for each extraction - self._merged_table_names = self._extract_table_names( - self._input_files - ) - if self.max_table_size: - output_file = self._adjust_output_file_name(output_file) - self._create_empty_tables(output_file) - self._row_counts = self._initialize_row_counts() - # Merge temporary databases into newly created one - self._merge_temporary_databases(output_file, self._input_files) - else: - self.warning("No temporary database files found!") - - # Internal methods - def _adjust_output_file_name(self, output_file: str) -> str: - if "_part_" in output_file: - root = ( - output_file.split("_part_")[0] - + output_file.split("_part_")[1][1:] - ) - else: - root = output_file - str_list = root.split(".db") - return str_list[0] + f"_part_{self._partition_count}" + ".db" - - def _update_row_counts( - self, results: "OrderedDict[str, pd.DataFrame]" - ) -> None: - for table_name, data in results.items(): - self._row_counts[table_name] += len(data) - return - - def _initialize_row_counts(self) -> Dict[str, int]: - """Build dictionary with row counts. Initialized with 0. - - Returns: - Dictionary where every field is a table name that contains - corresponding row counts. - """ - row_counts = {} - for table_name in self._merged_table_names: - row_counts[table_name] = 0 - return row_counts - - def _create_empty_tables(self, output_file: str) -> None: - """Create tables for output database. - - Args: - output_file: Path to database. - """ - for table_name in self._merged_table_names: - column_names = self._extract_column_names( - self._input_files, table_name - ) - if len(column_names) > 1: - create_table( - column_names, - table_name, - output_file, - default_type="FLOAT", - integer_primary_key=not ( - is_pulse_map(table_name) or is_mc_tree(table_name) - ), - ) - - def _get_tables_in_database(self, db: str) -> Tuple[str, ...]: - with sqlite3.connect(db) as conn: - table_names = tuple( - [ - p[0] - for p in ( - conn.execute( - "SELECT name FROM sqlite_master WHERE type='table';" - ).fetchall() - ) - ] - ) - return table_names - - def _extract_table_names( - self, db: Union[str, List[str]] - ) -> Tuple[str, ...]: - """Get the names of all tables in database `db`.""" - if isinstance(db, str): - db = [db] - results = [self._get_tables_in_database(path) for path in db] - # @TODO: Check... - if all([results[0] == r for r in results]): - return results[0] - else: - unique_tables = [] - for tables in results: - for table in tables: - if table not in unique_tables: - unique_tables.append(table) - return tuple(unique_tables) - - def _extract_column_names( - self, db_paths: List[str], table_name: str - ) -> List[str]: - for db_path in db_paths: - tables_in_database = self._get_tables_in_database(db_path) - if table_name in tables_in_database: - with sqlite3.connect(db_path) as con: - query = f"select * from {table_name} limit 1" - columns = pd.read_sql(query, con).columns - if len(columns): - return columns - return [] - - def any_pulsemap_is_non_empty(self, data_dict: Dict[str, Dict]) -> bool: - """Check whether there are non-empty pulsemaps extracted from P frame. 
- - Takes in the data extracted from the P frame, then retrieves the - values, if there are any, from the pulsemap key(s) (e.g - SplitInIcePulses). If at least one of the pulsemaps is non-empty then - return true. If no pulsemaps exist, i.e., if no `I3FeatureExtractor` is - called e.g. because `I3GenericExtractor` is used instead, always return - True. - """ - if len(self._pulsemaps) == 0: - return True - - pulsemap_dicts = [data_dict[pulsemap] for pulsemap in self._pulsemaps] - return any(d["dom_x"] for d in pulsemap_dicts) - - def _submit_to_database( - self, database: str, key: str, data: pd.DataFrame - ) -> None: - """Submit data to the database with specified key.""" - if len(data) == 0: - self.info(f"No data provided for {key}.") - return - engine = sqlalchemy.create_engine("sqlite:///" + database) - data.to_sql(key, engine, index=False, if_exists="append") - engine.dispose() - - def _extract_everything(self, db: str) -> "OrderedDict[str, pd.DataFrame]": - """Extract everything from the temporary database `db`. - - Args: - db: Path to temporary database. - - Returns: - Dictionary containing the data for each extracted table. - """ - results = OrderedDict() - table_names = self._extract_table_names(db) - with sqlite3.connect(db) as conn: - for table_name in table_names: - query = f"select * from {table_name}" - try: - data = pd.read_sql(query, conn) - except: # noqa: E722 - data = [] - results[table_name] = data - return results - - def _merge_temporary_databases( - self, - output_file: str, - input_files: List[str], - ) -> None: - """Merge the temporary databases. - - Args: - output_file: path to the final database - input_files: list of names of temporary databases - """ - file_count = 0 - for input_file in tqdm(input_files, colour="green"): - results = self._extract_everything(input_file) - for table_name, data in results.items(): - self._submit_to_database(output_file, table_name, data) - file_count += 1 - if (self.max_table_size is not None) & ( - file_count < len(input_files) - ): - self._update_row_counts(results) - maximum_row_count_reached = False - for table in self._row_counts.keys(): - assert self.max_table_size is not None - if self._row_counts[table] >= self.max_table_size: - maximum_row_count_reached = True - if maximum_row_count_reached: - self._partition_count += 1 - output_file = self._adjust_output_file_name(output_file) - self.info( - f"Maximum row count reached. Creating new partition at {output_file}" - ) - self._create_empty_tables(output_file) - self._row_counts = self._initialize_row_counts() - - -# Implementation-specific utility function(s) -def construct_dataframe(extraction: Dict[str, Any]) -> pd.DataFrame: - """Convert extraction to pandas.DataFrame. - - Args: - extraction: Dictionary with the extracted data. - - Returns: - Extraction as pandas.DataFrame. 
- """ - all_scalars = True - for value in extraction.values(): - if isinstance(value, (list, tuple, dict)): - all_scalars = False - break - - out = pd.DataFrame(extraction, index=[0] if all_scalars else None) - return out - - -def is_pulse_map(table_name: str) -> bool: - """Check whether `table_name` corresponds to a pulse map.""" - return "pulse" in table_name.lower() or "series" in table_name.lower() - - -def is_mc_tree(table_name: str) -> bool: - """Check whether `table_name` corresponds to an MC tree.""" - return "I3MCTree" in table_name diff --git a/src/graphnet/data/utilities/__init__.py b/src/graphnet/data/utilities/__init__.py index 0dd9e0600..ad4f0c7db 100644 --- a/src/graphnet/data/utilities/__init__.py +++ b/src/graphnet/data/utilities/__init__.py @@ -1 +1,4 @@ """Utilities for use across `graphnet.data`.""" +from .sqlite_utilities import create_table_and_save_to_sql +from .sqlite_utilities import get_primary_keys +from .sqlite_utilities import query_database diff --git a/src/graphnet/data/utilities/parquet_to_sqlite.py b/src/graphnet/data/utilities/parquet_to_sqlite.py index 146e69ce8..11114698e 100644 --- a/src/graphnet/data/utilities/parquet_to_sqlite.py +++ b/src/graphnet/data/utilities/parquet_to_sqlite.py @@ -9,7 +9,9 @@ import pandas as pd from tqdm.auto import trange -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.utilities.logging import Logger diff --git a/src/graphnet/data/sqlite/sqlite_utilities.py b/src/graphnet/data/utilities/sqlite_utilities.py similarity index 72% rename from src/graphnet/data/sqlite/sqlite_utilities.py rename to src/graphnet/data/utilities/sqlite_utilities.py index 23bae802d..cfa308ba2 100644 --- a/src/graphnet/data/sqlite/sqlite_utilities.py +++ b/src/graphnet/data/utilities/sqlite_utilities.py @@ -1,7 +1,7 @@ """SQLite-specific utility functions for use in `graphnet.data`.""" import os.path -from typing import List +from typing import List, Dict, Tuple import pandas as pd import sqlalchemy @@ -16,6 +16,58 @@ def database_exists(database_path: str) -> bool: return os.path.exists(database_path) +def query_database(database: str, query: str) -> pd.DataFrame: + """Execute query on database, and return result. + + Args: + database: path to database. + query: query to be executed. + + Returns: + DataFrame containing the result of the query. + """ + with sqlite3.connect(database) as conn: + return pd.read_sql(query, conn) + + +def get_primary_keys(database: str) -> Tuple[Dict[str, str], str]: + """Get name of primary key column for each table in database. + + Args: + database: path to database. + + Returns: + A dictionary containing the names of primary keys in each table of + `database`. E.g. {'truth': "event_no", + 'SplitInIcePulses': None} + Name of the primary key. 
+ """ + with sqlite3.connect(database) as conn: + query = 'SELECT name FROM sqlite_master WHERE type == "table"' + table_names = [table[0] for table in conn.execute(query).fetchall()] + + integer_primary_key = {} + for table in table_names: + query = f"SELECT l.name FROM pragma_table_info('{table}') as l WHERE l.pk = 1;" + first_primary_key = [ + key[0] for key in conn.execute(query).fetchall() + ] + integer_primary_key[table] = ( + first_primary_key[0] if len(first_primary_key) else None + ) + + # Get the primary key column name + primary_key_candidates = [] + for val in set(integer_primary_key.values()): + if val is not None: + primary_key_candidates.append(val) + + # There should only be one primary key: + assert len(primary_key_candidates) == 1 + + return integer_primary_key, primary_key_candidates[0] + + def database_table_exists(database_path: str, table_name: str) -> bool: """Check whether `table_name` exists in database at `database_path`.""" if not database_exists(database_path): diff --git a/src/graphnet/data/writers/__init__.py b/src/graphnet/data/writers/__init__.py new file mode 100644 index 000000000..ad3e2748e --- /dev/null +++ b/src/graphnet/data/writers/__init__.py @@ -0,0 +1,4 @@ +"""Modules for saving interim dataformat to various data backends.""" +from .graphnet_writer import GraphNeTWriter +from .parquet_writer import ParquetWriter +from .sqlite_writer import SQLiteWriter diff --git a/src/graphnet/data/writers/graphnet_writer.py b/src/graphnet/data/writers/graphnet_writer.py new file mode 100644 index 000000000..f6ec03029 --- /dev/null +++ b/src/graphnet/data/writers/graphnet_writer.py @@ -0,0 +1,94 @@ +"""Module containing `GraphNeTFileSaveMethod`(s). + +These modules are used to save the interim data format from `DataConverter` to +a deep-learning friendly file format. +""" + +import os +from typing import Dict, List, Union +from abc import abstractmethod, ABC + +from graphnet.utilities.decorators import final +from graphnet.utilities.logging import Logger + +import pandas as pd + + +class GraphNeTWriter(Logger, ABC): + """Generic base class for saving interim data format in `DataConverter`. + + Classes inheriting from `GraphNeTFileSaveMethod` must implement the + `save_file` method, which recieves the interim data format from + from a single file. + + In addition, classes inheriting from `GraphNeTFileSaveMethod` must + set the `file_extension` property. + """ + + @abstractmethod + def _save_file( + self, + data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]], + output_file_path: str, + n_events: int, + ) -> None: + """Save the interim data format from a single input file. + + Args: + data: the interim data from a single input file. + output_file_path: output file path. + n_events: Number of events container in `data`. + """ + raise NotImplementedError + + @abstractmethod + def merge_files( + self, + files: List[str], + output_dir: str, + ) -> None: + """Merge smaller files. + + Args: + files: Files to be merged. + output_dir: The directory to store the merged files in. + """ + raise NotImplementedError + + @final + def __call__( + self, + data: Union[Dict[str, pd.DataFrame], Dict[str, List[pd.DataFrame]]], + file_name: str, + output_dir: str, + n_events: int, + ) -> None: + """Save data. + + Args: + data: data to be saved. + file_name: name of input file. Will be used to generate output + file name. + output_dir: directory to save data to. + n_events: Number of events in `data`. 
+ """ + # make dir + os.makedirs(output_dir, exist_ok=True) + output_file_path = ( + os.path.join(output_dir, file_name) + self.file_extension + ) + + self._save_file( + data=data, output_file_path=output_file_path, n_events=n_events + ) + return + + @property + def file_extension(self) -> str: + """Return file extension used to store the data.""" + return self._file_extension # type: ignore + + @property + def expects_merged_dataframes(self) -> bool: + """Return if writer expects input to be merged dataframes or not.""" + return self._merge_dataframes # type: ignore diff --git a/src/graphnet/data/writers/parquet_writer.py b/src/graphnet/data/writers/parquet_writer.py new file mode 100644 index 000000000..18e524ca9 --- /dev/null +++ b/src/graphnet/data/writers/parquet_writer.py @@ -0,0 +1,51 @@ +"""DataConverter for the Parquet backend.""" + +import os +from typing import List, Optional, Dict + +import awkward +import pandas as pd + +from .graphnet_writer import GraphNeTWriter + + +class ParquetWriter(GraphNeTWriter): + """Class for writing interim data format to Parquet.""" + + # Class variables + _file_extension = ".parquet" + _merge_dataframes = False + + # Abstract method implementation(s) + def _save_file( + self, + data: Dict[str, List[pd.DataFrame]], + output_file_path: str, + n_events: int, + ) -> None: + """Save data to parquet.""" + # Check(s) + + if n_events > 0: + events = [] + for k in range(n_events): + event = {} + for table in data.keys(): + event[table] = data[table][k].to_dict(orient="list") + + events.append(event) + + awkward.to_parquet(awkward.from_iter(events), output_file_path) + + def merge_files(self, files: List[str], output_dir: str) -> None: + """Merge parquet files. + + Args: + files: input files for merging. + output_dir: directory to store merged file(s) in. + + Raises: + NotImplementedError + """ + self.error(f"{self.__class__.__name__} does not have a merge method.") + raise NotImplementedError diff --git a/src/graphnet/data/writers/sqlite_writer.py b/src/graphnet/data/writers/sqlite_writer.py new file mode 100644 index 000000000..ab8d95051 --- /dev/null +++ b/src/graphnet/data/writers/sqlite_writer.py @@ -0,0 +1,226 @@ +"""Module containing `GraphNeTFileSaveMethod`(s). + +These modules are used to save the interim data format from `DataConverter` to +a deep-learning friendly file format. +""" + +import os +from tqdm import tqdm +from typing import List, Dict, Optional + +from graphnet.data.utilities import ( + create_table_and_save_to_sql, + get_primary_keys, + query_database, +) +import pandas as pd +from .graphnet_writer import GraphNeTWriter + + +class SQLiteWriter(GraphNeTWriter): + """A method for saving GraphNeT's interim dataformat to SQLite.""" + + def __init__( + self, + merged_database_name: str = "merged.db", + max_table_size: Optional[int] = None, + ) -> None: + """Initialize `SQLiteWriter`. + + Args: + merged_database_name: name of the database, not path, that files + will be merged into. Defaults to "merged.db". + max_table_size: The maximum number of rows in any given table. + If given, the merging proceedure splits the databases into + partitions each with a maximum table size of max_table_size. + Note that the size is approximate. This feature is useful if + you have many events, as tables exceeding + 400 million rows tend to be noticably slower to query. + Defaults to None (All events are put into a single database). 
+ """ + # Member Variables + self._file_extension = ".db" + self._merge_dataframes = True + self._max_table_size = max_table_size + self._database_name = merged_database_name + + # Add file extension to database name if forgotten + if not self._database_name.endswith(self._file_extension): + self._database_name = self._database_name + self._file_extension + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + def _save_file( + self, + data: Dict[str, pd.DataFrame], + output_file_path: str, + n_events: int, + ) -> None: + """Save data to SQLite database.""" + # Check(s) + if os.path.exists(output_file_path): + self.warning( + f"Output file {output_file_path} already exists. Appending." + ) + + # Concatenate data + if len(data) == 0: + self.warning( + "No data was extracted from the processed I3 file(s). " + f"No data saved to {output_file_path}" + ) + return + + saved_any = False + # Save each dataframe to SQLite database + self.debug(f"Saving to {output_file_path}") + for table, df in data.items(): + if len(df) > 0: + create_table_and_save_to_sql( + df, + table, + output_file_path, + default_type="FLOAT", + integer_primary_key=len(df) <= n_events, + ) + saved_any = True + + if saved_any: + self.debug("- Done saving") + else: + self.warning(f"No data saved to {output_file_path}") + + def merge_files( + self, + files: List[str], + output_dir: str, + ) -> None: + """SQLite-specific method for merging output files/databases. + + Args: + files: paths to SQLite databases that needs to be merged. + output_dir: path to store the merged database(s) in. + database_name: name, not path, of database. E.g. "my_database". + max_table_size: The maximum number of rows in any given table. + If given, the merging proceedure splits the databases into + partitions each with a maximum table size of max_table_size. + Note that the size is approximate. This feature is useful if + you have many events, as tables exceeding + 400 million rows tend to be noticably slower to query. + Defaults to None (All events are put into a single database.) + """ + # Warnings + if self._max_table_size: + self.warning( + f"Merging got max_table_size of {self._max_table_size}." + " Will attempt to create databases with a maximum row count of" + " this size." + ) + + # Set variables + self._partition_count = 1 + + # Construct full database path + database_path = os.path.join(output_dir, self._database_name) + print(database_path) + # Start merging if files are given + if len(files) > 0: + os.makedirs(output_dir, exist_ok=True) + self.info(f"Merging {len(files)} database files") + self._merge_databases(files=files, database_path=database_path) + else: + self.warning("No database files given! Exiting.") + + def _merge_databases( + self, + files: List[str], + database_path: str, + ) -> None: + """Merge the temporary databases. + + Args: + files: List of files to be merged. + database_path: Path to a database, can be an empty path, where the + databases listed in `files` will be merged into. If no database + exists at the given path, one will be created. 
+ """ + if os.path.exists(database_path): + self.warning( + "Target path for merged database", + f"{database_path}, already exists.", + ) + + if self._max_table_size is not None: + database_path = self._adjust_output_path(database_path) + self._row_counts: Dict[str, int] = {} + self._largest_table = 0 + + # Merge temporary databases into newly created one + for file_count, input_file in tqdm(enumerate(files), colour="green"): + + # Extract table names and index column name in database + tables, primary_key = get_primary_keys(database=input_file) + + for table_name in tables.keys(): + # Extract all data in the table from the given database + df = query_database( + database=input_file, query=f"SELECT * FROM {table_name}" + ) + + # Infer whether the table was previously indexed with + # A primary key or not. len(tables[table]) = 0 if not. + integer_primary_key = ( + True if tables[table_name] is not None else False + ) + + # Submit to new database + create_table_and_save_to_sql( + df=df, + table_name=table_name, + database_path=database_path, + index_column=primary_key, + integer_primary_key=integer_primary_key, + default_type="FLOAT", + ) + + # Update row counts if needed + if self._max_table_size is not None: + self._update_row_counts(df=df, table_name=table_name) + + if (self._max_table_size is not None) & (file_count < len(files)): + assert self._max_table_size is not None # mypy... + if self._largest_table >= self._max_table_size: + # Increment partition, reset counts, adjust output path + self._partition_count += 1 + self._row_counts = {} + self._largest_table = 0 + database_path = self._adjust_output_path(database_path) + self.info( + "Maximum row count reached." + f" Creating new partition at {database_path}" + ) + + # Internal methods + + def _adjust_output_path(self, output_file: str) -> str: + """Adjust the file path to reflect that it is a partition.""" + path_without_extension, extension = os.path.splitext(output_file) + if "_part_" in path_without_extension: + # if true, this is already a partition. + database_name = path_without_extension.split("_part_")[:-1][0] + else: + database_name = path_without_extension + # split into multiple lines to avoid one long + database_name = database_name + f"_part_{self._partition_count}" + database_name = database_name + extension + return database_name + + def _update_row_counts(self, df: pd.DataFrame, table_name: str) -> None: + if table_name in self._row_counts.keys(): + self._row_counts[table_name] += len(df) + else: + self._row_counts[table_name] = len(df) + + self._largest_table = max(self._row_counts.values()) + return diff --git a/src/graphnet/deployment/__init__.py b/src/graphnet/deployment/__init__.py index 7c0c342d0..c26dfc1b7 100644 --- a/src/graphnet/deployment/__init__.py +++ b/src/graphnet/deployment/__init__.py @@ -3,3 +3,5 @@ `graphnet.deployment` allows for using trained models for inference in domain- specific reconstruction chains. 
""" +from .deployer import Deployer +from .deployment_module import DeploymentModule diff --git a/src/graphnet/deployment/deployer.py b/src/graphnet/deployment/deployer.py new file mode 100644 index 000000000..4ac67f4d8 --- /dev/null +++ b/src/graphnet/deployment/deployer.py @@ -0,0 +1,131 @@ +"""Contains the graphnet deployment module.""" +import random +from abc import abstractmethod, ABC +import multiprocessing +from typing import TYPE_CHECKING, List, Union, Sequence, Any +import time + +from graphnet.utilities.imports import has_torch_package +from .deployment_module import DeploymentModule +from graphnet.utilities.logging import Logger + +if has_torch_package or TYPE_CHECKING: + import torch + + +class Deployer(ABC, Logger): + """A generic baseclass for applying `DeploymentModules` to analysis files. + + Modules are applied in the order that they appear in `modules`. + """ + + @abstractmethod + def _process_files( + self, + settings: Any, + ) -> None: + """Process a single file. + + If n_workers > 1, this function is run in parallel n_worker times. Each + worker will loop over an allocated set of files. + """ + raise NotImplementedError + + @abstractmethod + def _prepare_settings( + self, input_files: List[str], output_folder: str + ) -> List[Any]: + """Produce a list of inputs for each worker. + + This function must produce and return a list of arguments to each + worker. + """ + raise NotImplementedError + + def __init__( + self, + modules: Union[DeploymentModule, Sequence[DeploymentModule]], + n_workers: int = 1, + ) -> None: + """Initialize `Deployer`. + + Will apply `DeploymentModules` to files in the order in which they + appear in `modules`. Each module is run independently. + + Args: + modules: List of `DeploymentModules`. + Order of appearence in the list determines order + of deployment. + n_workers: Number of workers. The deployer will divide the number + of input files across workers. Defaults to 1. + """ + super().__init__(name=__name__, class_name=self.__class__.__name__) + # This makes sure that one worker cannot access more + # than 1 core's worth of compute. + + if torch.get_num_interop_threads() > 1: + torch.set_num_interop_threads(1) + if torch.get_num_threads() > 1: + torch.set_num_threads(1) + + # Check + if isinstance(modules, list): + self._modules = modules + else: + self._modules = [modules] + + # Member Variables + self._n_workers = n_workers + + def _launch_jobs(self, settings: List[Any]) -> None: + """Will launch jobs in parallel if n_workers > 1, else run on main.""" + if self._n_workers > 1: + processes = [] + for i in range(self._n_workers): + processes.append( + multiprocessing.Process( + target=self._process_files, + args=[settings[i]], # type: ignore + ) + ) + + for process in processes: + process.start() + + for process in processes: + process.join() + else: + self._process_files(settings[0]) + + def run( + self, + input_files: Union[List[str], str], + output_folder: str, + ) -> None: + """Apply `modules` to input files. + + Args: + input_files: Path(s) to i3 file(s) that you wish to + apply the graphnet modules to. + output_folder: The output folder to which the i3 files are written. 
+ """ + start_time = time.time() + if isinstance(input_files, list): + random.shuffle(input_files) + else: + input_files = [input_files] + settings = self._prepare_settings( + input_files=input_files, output_folder=output_folder + ) + + assert len(settings) == self._n_workers + + self.info( + f"""processing {len(input_files)} files \n + using {self._n_workers} workers""" + ) + self._launch_jobs(settings) + self.info( + f"""Processing {len(input_files)} files was completed in \n + {time.time() - start_time} seconds using {self._n_workers} cores.""" + ) diff --git a/src/graphnet/deployment/deployment_module.py b/src/graphnet/deployment/deployment_module.py new file mode 100644 index 000000000..0083d5dce --- /dev/null +++ b/src/graphnet/deployment/deployment_module.py @@ -0,0 +1,96 @@ +"""Class(es) for deploying GraphNeT models in icetray as I3Modules.""" +from abc import abstractmethod +from typing import Any, List, Union, Dict + +import numpy as np +from torch import Tensor +from torch_geometric.data import Data, Batch + +from graphnet.models import Model +from graphnet.utilities.config import ModelConfig +from graphnet.utilities.logging import Logger + + +class DeploymentModule(Logger): + """Base DeploymentModule for GraphNeT. + + Contains standard methods for loading models doing inference with them. + Experiment-specific implementations may overwrite methods and should define + `__call__`. + """ + + def __init__( + self, + model_config: Union[ModelConfig, str], + state_dict: Union[Dict[str, Tensor], str], + device: str = "cpu", + prediction_columns: Union[List[str], None] = None, + ): + """Construct DeploymentModule. + + Arguments: + model_config: A model configuration file. + state_dict: A state dict for the model. + device: The computational device to use. Defaults to "cpu". + prediction_columns: Column names for each column in model output. + """ + super().__init__(name=__name__, class_name=self.__class__.__name__) + # Set Member Variables + self.model = self._load_model( + model_config=model_config, state_dict=state_dict + ) + + self.prediction_columns = self._resolve_prediction_columns( + prediction_columns + ) + + # Set model to inference mode. + self.model.inference() + + # Move model to device + self.model.to(device) + + @abstractmethod + def __call__(self, input_data: Any) -> Any: + """Define here how the module acts on a file/data stream.""" + + def _load_model( + self, + model_config: Union[ModelConfig, str], + state_dict: Union[Dict[str, Tensor], str], + ) -> Model: + """Load `Model` from config and insert learned weights.""" + model = Model.from_config(model_config, trust=True) + model.load_state_dict(state_dict) + return model + + def _resolve_prediction_columns( + self, prediction_columns: Union[List[str], None] + ) -> List[str]: + if prediction_columns is not None: + if isinstance(prediction_columns, str): + prediction_columns = [prediction_columns] + else: + prediction_columns = prediction_columns + else: + prediction_columns = self.model.prediction_labels + return prediction_columns + + def _inference(self, data: Union[Data, Batch]) -> List[np.ndarray]: + """Apply model to a single event or batch of events `data`. + + Args: + data: A `Data` or ``Batch` object - + either a single output of a `GraphDefinition` or a batch of + them. + + Returns: + A List of numpy arrays, each representing the output from the + `Task`s that the model contains. 
+ """ + # Perform inference + output = self.model(data=data) + # Loop over tasks in model and transform to numpy + for k in range(len(output)): + output[k] = output[k].detach().numpy() + return output diff --git a/src/graphnet/deployment/i3modules/__init__.py b/src/graphnet/deployment/i3modules/__init__.py index de4fcca7d..e2fd05a43 100644 --- a/src/graphnet/deployment/i3modules/__init__.py +++ b/src/graphnet/deployment/i3modules/__init__.py @@ -5,5 +5,5 @@ detector configurations. """ -from .graphnet_module import * -from .deployer import * +from .deprecated_methods import * +from graphnet.deployment.icecube import I3InferenceModule, I3PulseCleanerModule diff --git a/src/graphnet/deployment/i3modules/deployer.py b/src/graphnet/deployment/i3modules/deployer.py deleted file mode 100644 index 2228005f8..000000000 --- a/src/graphnet/deployment/i3modules/deployer.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Contains the graphnet i3 deployment module.""" -import os.path -import os -import random -import multiprocessing -from typing import TYPE_CHECKING, List, Union, Sequence -import time -import numpy as np -from dataclasses import dataclass - -from graphnet.utilities.imports import has_icecube_package, has_torch_package -from graphnet.deployment.i3modules import ( - GraphNeTI3Module, -) - -if has_icecube_package() or TYPE_CHECKING: - from icecube import icetray, dataio # pyright: reportMissingImports=false - from I3Tray import I3Tray - -if has_torch_package or TYPE_CHECKING: - import torch - - -@dataclass -class Settings: - """Dataclass for workers in GraphNeTI3Deployer.""" - - i3_files: List[str] - gcd_file: str - output_folder: str - modules: List[GraphNeTI3Module] - - -class GraphNeTI3Deployer: - """Deploys graphnet i3 modules to i3 files. - - Modules are applied in the order in which they appear in graphnet_modules. - """ - - def __init__( - self, - graphnet_modules: Union[GraphNeTI3Module, Sequence[GraphNeTI3Module]], - gcd_file: str, - n_workers: int = 1, - ) -> None: - """Initialize the deployer. - - Will apply graphnet i3 modules to i3 files in the order in which they - appear in graphnet_modules.Each module is run independently. - - Args: - graphnet_modules: List of graphnet i3 modules. - Order of appearence in the list determines order - of deployment. - gcd_file: path to gcd file. - n_workers: Number of workers. The deployer will divide the number - of input files across workers. Defaults to 1. - """ - # This makes sure that one worker cannot access more - # than 1 core's worth of compute. - - if torch.get_num_interop_threads() > 1: - torch.set_num_interop_threads(1) - if torch.get_num_threads() > 1: - torch.set_num_threads(1) - # Check - if isinstance(graphnet_modules, list): - self._modules = graphnet_modules - else: - self._modules = [graphnet_modules] - self._gcd_file = gcd_file - self._n_workers = n_workers - - def _prepare_settings( - self, input_files: List[str], output_folder: str - ) -> List[Settings]: - """Will prepare the settings for each worker.""" - try: - os.makedirs(output_folder) - except FileExistsError: - assert False, f"""{output_folder} already exists. 
To avoid overwriting \n - existing files, the process has been stopped.""" - if self._n_workers > len(input_files): - self._n_workers = len(input_files) - if self._n_workers > 1: - file_batches = np.array_split(input_files, self._n_workers) - settings: List[Settings] = [] - for i in range(self._n_workers): - settings.append( - Settings( - file_batches[i], - self._gcd_file, - output_folder, - self._modules, - ) - ) - else: - settings = [ - Settings( - input_files, - self._gcd_file, - output_folder, - self._modules, - ) - ] - return settings - - def _launch_jobs(self, settings: List[Settings]) -> None: - """Will launch jobs in parallel if n_workers > 1, else run on main.""" - if self._n_workers > 1: - processes = [] - for i in range(self._n_workers): - processes.append( - multiprocessing.Process( - target=self._process_files, - args=[settings[i]], # type: ignore - ) - ) - - for process in processes: - process.start() - - for process in processes: - process.join() - else: - self._process_files(settings[0]) - - def _process_files( - self, - settings: Settings, - ) -> None: - """Will start an IceTray read/write chain with graphnet modules. - - If n_workers > 1, this function is run in parallel n_worker times. Each - worker will loop over an allocated set of i3 files. The new i3 files - will appear as copies of the original i3 files but with reconstructions - added. Original i3 files are left untouched. - """ - for i3_file in settings.i3_files: - tray = I3Tray() - tray.context["I3FileStager"] = dataio.get_stagers() - tray.AddModule( - "I3Reader", - "reader", - FilenameList=[settings.gcd_file, i3_file], - ) - for i3_module in settings.modules: - tray.AddModule(i3_module) - tray.Add( - "I3Writer", - Streams=[ - icetray.I3Frame.DAQ, - icetray.I3Frame.Physics, - icetray.I3Frame.TrayInfo, - icetray.I3Frame.Simulation, - ], - filename=settings.output_folder + "/" + i3_file.split("/")[-1], - ) - tray.Execute() - tray.Finish() - return - - def run( - self, - input_files: Union[List[str], str], - output_folder: str, - ) -> None: - """Apply given graphnet modules to input files using n workers. - - The i3 files with reconstructions will appear as copies of - the original i3 files but with reconstructions added. - Original i3 files are left untouched. - - Args: - input_files: Path(s) to i3 file(s) that you wish to - apply the graphnet modules to. - output_folder: The output folder to which the i3 files are written. - """ - start_time = time.time() - if isinstance(input_files, list): - random.shuffle(input_files) - else: - input_files = [input_files] - settings = self._prepare_settings( - input_files=input_files, output_folder=output_folder - ) - print( - f"""processing {len(input_files)} i3 files \n - using {self._n_workers} workers""" - ) - self._launch_jobs(settings) - print( - f"""Processing {len(input_files)} files was completed in \n - {time.time() - start_time} seconds using {self._n_workers} cores.""" - ) diff --git a/src/graphnet/deployment/i3modules/deprecated_methods.py b/src/graphnet/deployment/i3modules/deprecated_methods.py new file mode 100644 index 000000000..6acdc8d33 --- /dev/null +++ b/src/graphnet/deployment/i3modules/deprecated_methods.py @@ -0,0 +1,43 @@ +"""Contains deprecated methods.""" +from typing import Union, Sequence + +# from graphnet.deployment.icecube import I3Deployer, I3InferenceModule +from ..icecube.i3deployer import I3Deployer +from ..icecube.inference_module import I3InferenceModule + + +class GraphNeTI3Deployer(I3Deployer): + """Class has been renamed to `I3Deployer`. 
+ + Please use `I3Deployer` instead. + """ + + def __init__( + self, + graphnet_modules: Union[ + I3InferenceModule, Sequence[I3InferenceModule] + ], + gcd_file: str, + n_workers: int = 1, + ) -> None: + """Initialize `GraphNeTI3Deployer`. + + Will apply `DeploymentModules` to files in the order in which they + appear in `modules`. Each module is run independently. + + Args: + graphnet_modules: List of `DeploymentModules`. + Order of appearence in the list determines order + of deployment. + gcd_file: path to gcd file. + n_workers: Number of workers. The deployer will divide the number + of input files across workers. Defaults to 1. + """ + super().__init__( + modules=graphnet_modules, n_workers=n_workers, gcd_file=gcd_file + ) + self.warning( + f"{self.__class__} will be deprecated in GraphNeT 2.0" + " Please use `I3Deployer` instead. " + " E.g.: `from graphnet.deployment.icecube import I3Deployer`" + ) diff --git a/src/graphnet/deployment/i3modules/graphnet_module.py b/src/graphnet/deployment/i3modules/graphnet_module.py deleted file mode 100644 index d3aa878e0..000000000 --- a/src/graphnet/deployment/i3modules/graphnet_module.py +++ /dev/null @@ -1,453 +0,0 @@ -"""Class(es) for deploying GraphNeT models in icetray as I3Modules.""" -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, List, Union, Dict, Tuple, Optional - -import dill -import numpy as np -import torch -from torch_geometric.data import Data, Batch - -from graphnet.data.extractors import ( - I3FeatureExtractor, - I3FeatureExtractorIceCubeUpgrade, -) -from graphnet.models import Model, StandardModel -from graphnet.models.graphs import GraphDefinition -from graphnet.utilities.imports import has_icecube_package -from graphnet.utilities.config import ModelConfig -from graphnet.utilities.logging import Logger - -if has_icecube_package() or TYPE_CHECKING: - from icecube.icetray import ( - I3Module, - I3Frame, - ) # pyright: reportMissingImports=false - from icecube.dataclasses import ( - I3Double, - I3MapKeyVectorDouble, - ) # pyright: reportMissingImports=false - from icecube import dataclasses, dataio, icetray - - -class GraphNeTI3Module(Logger): - """Base I3 Module for GraphNeT. - - Contains methods for extracting pulsemaps, producing graphs and writing to - frames. - """ - - def __init__( - self, - graph_definition: GraphDefinition, - pulsemap: str, - features: List[str], - pulsemap_extractor: Union[ - List[I3FeatureExtractor], I3FeatureExtractor - ], - gcd_file: str, - ): - """I3Module Constructor. - - Arguments: - graph_definition: An instance of GraphDefinition. E.g. KNNGraph. - pulsemap: the pulse map on which the module functions - features: the features that is used from the pulse map. - E.g. [dom_x, dom_y, dom_z, charge] - pulsemap_extractor: The I3FeatureExtractor used to extract the - pulsemap from the I3Frames - gcd_file: Path to the associated gcd-file. 
- """ - super().__init__(name=__name__, class_name=self.__class__.__name__) - assert isinstance(graph_definition, GraphDefinition) - self._graph_definition = graph_definition - self._pulsemap = pulsemap - self._features = features - assert isinstance(gcd_file, str), "gcd_file must be string" - self._gcd_file = gcd_file - if isinstance(pulsemap_extractor, list): - self._i3_extractors = pulsemap_extractor - else: - self._i3_extractors = [pulsemap_extractor] - - for i3_extractor in self._i3_extractors: - i3_extractor.set_files(i3_file="", gcd_file=self._gcd_file) - - @abstractmethod - def __call__(self, frame: I3Frame) -> bool: - """Define here how the module acts on the frame. - - Must return True if successful. - - Return True # SUPER IMPORTANT - """ - - def _make_graph( - self, frame: I3Frame - ) -> Data: # py-l-i-n-t-:- -d-i-s-able=invalid-name - """Process Physics I3Frame into graph.""" - # Extract features - input_features = self._extract_feature_array_from_frame(frame) - # Prepare graph data - if len(input_features) > 0: - data = self._graph_definition( - input_features=input_features, - input_feature_names=self._features, - ) - return Batch.from_data_list([data]) - else: - return None - - def _extract_feature_array_from_frame(self, frame: I3Frame) -> np.array: - """Apply the I3FeatureExtractors to the I3Frame. - - Arguments: - frame: Physics I3Frame (PFrame) - - Returns: - array with pulsemap - """ - features = None - for i3extractor in self._i3_extractors: - feature_dict = i3extractor(frame) - features_pulsemap = np.array( - [feature_dict[key] for key in self._features] - ).T - if features is None: - features = features_pulsemap - else: - features = np.concatenate( - (features, features_pulsemap), axis=0 - ) - return features - - def _add_to_frame(self, frame: I3Frame, data: Dict[str, Any]) -> I3Frame: - """Add every field in data to I3Frame. - - Arguments: - frame: I3Frame (physics) - data: Dictionary containing content that will be written to frame. - - Returns: - frame: Same I3Frame as input, but with the new entries - """ - assert isinstance( - data, dict - ), f"data must be of type dict. Got {type(data)}" - for key in data.keys(): - if key not in frame: - frame.Put(key, data[key]) - return frame - - -class I3InferenceModule(GraphNeTI3Module): - """General class for inference on i3 frames.""" - - def __init__( - self, - pulsemap: str, - features: List[str], - pulsemap_extractor: Union[ - List[I3FeatureExtractor], I3FeatureExtractor - ], - model_config: Union[ModelConfig, str], - state_dict: str, - model_name: str, - gcd_file: str, - prediction_columns: Optional[Union[List[str], str]] = None, - ): - """General class for inference on I3Frames (physics). - - Arguments: - pulsemap: the pulsmap that the model is expecting as input. - features: the features of the pulsemap that the model is expecting. - pulsemap_extractor: The extractor used to extract the pulsemap. - model_config: The ModelConfig (or path to it) that summarizes the - model used for inference. - state_dict: Path to state_dict containing the learned weights. - model_name: The name used for the model. Will help define the - named entry in the I3Frame. E.g. "dynedge". - gcd_file: path to associated gcd file. - prediction_columns: column names for the predictions of the model. - Will help define the named entry in the I3Frame. - E.g. ['energy_reco']. Optional. 
- """ - # Construct model & load weights - self.model = Model.from_config(model_config, trust=True) - self.model.load_state_dict(state_dict) - - super().__init__( - pulsemap=pulsemap, - features=features, - pulsemap_extractor=pulsemap_extractor, - gcd_file=gcd_file, - graph_definition=self.model._graph_definition, - ) - self.model.inference() - - self.model.to("cpu") - if prediction_columns is not None: - if isinstance(prediction_columns, str): - self.prediction_columns = [prediction_columns] - else: - self.prediction_columns = prediction_columns - else: - self.prediction_columns = self.model.prediction_labels - - self.model_name = model_name - - def __call__(self, frame: I3Frame) -> bool: - """Write predictions from model to frame.""" - # inference - graph = self._make_graph(frame) - if graph is not None: - predictions = self._inference(graph) - else: - self.warning( - f"At least one event has no pulses in {self._pulsemap} - padding {self.prediction_columns} with NaN." - ) - predictions = np.repeat( - [np.nan], len(self.prediction_columns) - ).reshape(-1, len(self.prediction_columns)) - # Check dimensions of predictions and prediction columns - if len(predictions.shape) > 1: - dim = predictions.shape[1] - else: - dim = len(predictions) - assert dim == len( - self.prediction_columns - ), f"""predictions have shape {dim} but \n - prediction columns have [{self.prediction_columns}]""" - - # Build Dictionary of predictions - data = {} - assert predictions.shape[0] == 1 - for i in range(dim if isinstance(dim, int) else len(dim)): - try: - assert len(predictions[:, i]) == 1 - data[ - self.model_name + "_" + self.prediction_columns[i] - ] = I3Double(float(predictions[:, i][0])) - except IndexError: - data[ - self.model_name + "_" + self.prediction_columns[i] - ] = I3Double(predictions[0]) - - # Submission methods - frame = self._add_to_frame(frame=frame, data=data) - return True - - def _inference(self, data: Data) -> np.ndarray: - # Perform inference - task_predictions = self.model(data) - assert ( - len(task_predictions) == 1 - ), f"""This method assumes a single task. \n - Got {len(task_predictions)} tasks.""" - return self.model(data)[0].detach().numpy() - - -class I3PulseCleanerModule(I3InferenceModule): - """A specialized module for pulse cleaning. - - It is assumed that the model provided has been trained for this. - """ - - def __init__( - self, - pulsemap: str, - features: List[str], - pulsemap_extractor: Union[ - List[I3FeatureExtractor], I3FeatureExtractor - ], - model_config: str, - state_dict: str, - model_name: str, - *, - gcd_file: str, - threshold: float = 0.7, - discard_empty_events: bool = False, - prediction_columns: Optional[Union[List[str], str]] = None, - ): - """General class for inference on I3Frames (physics). - - Arguments: - pulsemap: the pulsmap that the model is expecting as input - (the one that is being cleaned). - features: the features of the pulsemap that the model is expecting. - pulsemap_extractor: The extractor used to extract the pulsemap. - model_config: The ModelConfig (or path to it) that summarizes the - model used for inference. - state_dict: Path to state_dict containing the learned weights. - model_name: The name used for the model. Will help define the named - entry in the I3Frame. E.g. "dynedge". - gcd_file: path to associated gcd file. - threshold: the threshold for being considered a positive case. - E.g., predictions >= threshold will be considered - to be signal, all else noise. 
- discard_empty_events: When true, this flag will eliminate events - whose cleaned pulse series are empty. Can be used - to speed up processing especially for noise - simulation, since it will not do any writing or - further calculations. - prediction_columns: column names for the predictions of the model. - Will help define the named entry in the I3Frame. - E.g. ['energy_reco']. Optional. - """ - super().__init__( - pulsemap=pulsemap, - features=features, - pulsemap_extractor=pulsemap_extractor, - model_config=model_config, - state_dict=state_dict, - model_name=model_name, - prediction_columns=prediction_columns, - gcd_file=gcd_file, - ) - self._threshold = threshold - self._predictions_key = f"{pulsemap}_{model_name}_Predictions" - self._total_pulsemap_name = f"{pulsemap}_{model_name}_Pulses" - self._discard_empty_events = discard_empty_events - - def __call__(self, frame: I3Frame) -> bool: - """Add a cleaned pulsemap to frame.""" - # inference - gcd_file = self._gcd_file - graph = self._make_graph(frame) - if graph is None: # If there is no pulses to clean - return False - predictions = self._inference(graph) - if self._discard_empty_events: - if sum(predictions > self._threshold) == 0: - return False - - if len(predictions.shape) == 1: - predictions = predictions.reshape(-1, 1) - - assert predictions.shape[1] == 1 - - # Build Dictionary of predictions - data = {} - - predictions_map = self._construct_prediction_map( - frame=frame, predictions=predictions - ) - - # Adds the raw predictions to dictionary - if self._predictions_key not in frame.keys(): - data[self._predictions_key] = predictions_map - - # Create a pulse map mask, indicating the pulses that are over - # threshold (e.g. identified as signal) and therefore should be kept - # Using a lambda function to evaluate which pulses to keep by - # checking the prediction for each pulse - # (Adds the actual pulsemap to dictionary) - if self._total_pulsemap_name not in frame.keys(): - data[ - self._total_pulsemap_name - ] = dataclasses.I3RecoPulseSeriesMapMask( - frame, - self._pulsemap, - lambda om_key, index, pulse: predictions_map[om_key][index] - >= self._threshold, - ) - - # Submit predictions and general pulsemap - frame = self._add_to_frame(frame=frame, data=data) - data = {} - # Adds an additional pulsemap for each DOM type - if isinstance( - self._i3_extractors[0], I3FeatureExtractorIceCubeUpgrade - ): - mDOMMap, DEggMap, IceCubeMap = self._split_pulsemap_in_dom_types( - frame=frame, gcd_file=gcd_file - ) - - if f"{self._total_pulsemap_name}_mDOMs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_mDOMs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(mDOMMap) - - if f"{self._total_pulsemap_name}_dEggs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_dEggs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(DEggMap) - - if f"{self._total_pulsemap_name}_pDOMs_Only" not in frame.keys(): - data[ - f"{self._total_pulsemap_name}_pDOMs_Only" - ] = dataclasses.I3RecoPulseSeriesMap(IceCubeMap) - - # Submits the additional pulsemaps to the frame - frame = self._add_to_frame(frame=frame, data=data) - - return True - - def _split_pulsemap_in_dom_types( - self, frame: I3Frame, gcd_file: Any - ) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: - """Will split the cleaned pulsemap into multiple pulsemaps. 
- - Arguments: - frame: I3Frame (physics) - gcd_file: path to associated gcd file - - Returns: - mDOMMap, DeGGMap, IceCubeMap - """ - g = dataio.I3File(gcd_file) - gFrame = g.pop_frame() - while "I3Geometry" not in gFrame.keys(): - gFrame = g.pop_frame() - omGeoMap = gFrame["I3Geometry"].omgeo - - mDOMMap, DEggMap, IceCubeMap = {}, {}, {} - pulses = dataclasses.I3RecoPulseSeriesMap.from_frame( - frame, self._total_pulsemap_name - ) - for P in pulses: - om = omGeoMap[P[0]] - if om.omtype == 130: # "mDOM" - mDOMMap[P[0]] = P[1] - elif om.omtype == 120: # "DEgg" - DEggMap[P[0]] = P[1] - elif om.omtype == 20: # "IceCube / pDOM" - IceCubeMap[P[0]] = P[1] - return mDOMMap, DEggMap, IceCubeMap - - def _construct_prediction_map( - self, frame: I3Frame, predictions: np.ndarray - ) -> I3MapKeyVectorDouble: - """Make a pulsemap from predictions (for all OM types). - - Arguments: - frame: I3Frame (physics) - predictions: predictions from Model. - - Returns: - predictions_map: a pulsemap from predictions - """ - pulsemap = dataclasses.I3RecoPulseSeriesMap.from_frame( - frame, self._pulsemap - ) - - idx = 0 - predictions = predictions.squeeze(1) - predictions_map = dataclasses.I3MapKeyVectorDouble() - for om_key, pulses in pulsemap.items(): - num_pulses = len(pulses) - predictions_map[om_key] = predictions[ - idx : idx + num_pulses - ].tolist() - idx += num_pulses - - # Checks - assert idx == len( - predictions - ), """Not all predictions were mapped to pulses,\n - validation of predictions have failed.""" - - assert ( - pulsemap.keys() == predictions_map.keys() - ), """Input pulse map and predictions map do \n - not contain exactly the same OMs""" - return predictions_map diff --git a/src/graphnet/deployment/icecube/__init__.py b/src/graphnet/deployment/icecube/__init__.py new file mode 100644 index 000000000..15c7485ef --- /dev/null +++ b/src/graphnet/deployment/icecube/__init__.py @@ -0,0 +1,4 @@ +"""Deployment modules specific to IceCube.""" +from .inference_module import I3InferenceModule +from .cleaning_module import I3PulseCleanerModule +from .i3deployer import I3Deployer diff --git a/src/graphnet/deployment/icecube/cleaning_module.py b/src/graphnet/deployment/icecube/cleaning_module.py new file mode 100644 index 000000000..0a2417331 --- /dev/null +++ b/src/graphnet/deployment/icecube/cleaning_module.py @@ -0,0 +1,228 @@ +"""IceCube I3InferenceModule. + +Contains functionality for writing model predictions to i3 files. +""" +from typing import List, Union, Optional, TYPE_CHECKING, Dict, Any, Tuple + +import numpy as np + +from .inference_module import I3InferenceModule +from graphnet.utilities.config import ModelConfig +from graphnet.utilities.imports import has_icecube_package +from graphnet.data.extractors.icecube import ( + I3FeatureExtractor, + I3FeatureExtractorIceCubeUpgrade, +) + +if has_icecube_package() or TYPE_CHECKING: + from icecube.icetray import ( + I3Frame, + ) # pyright: reportMissingImports=false + from icecube.dataclasses import ( + I3MapKeyVectorDouble, + ) # pyright: reportMissingImports=false + from icecube import dataclasses, dataio + + +class I3PulseCleanerModule(I3InferenceModule): + """A specialized module for pulse cleaning. + + It is assumed that the model provided has been trained for this. 
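+ + Example (illustrative sketch; paths, names, features and the + `extractor` instance are placeholders): + + >>> cleaner = I3PulseCleanerModule( + ... pulsemap="SplitInIcePulses", + ... features=["dom_x", "dom_y", "dom_z", "dom_time", "charge"], + ... pulsemap_extractor=extractor, + ... model_config="/path/to/model_config.yml", + ... state_dict="/path/to/state_dict.pth", + ... model_name="cleaning_model", + ... gcd_file="/path/to/GeoCalibDetectorStatus.i3.gz", + ... threshold=0.7, + ... )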
+ """ + + def __init__( + self, + pulsemap: str, + features: List[str], + pulsemap_extractor: Union[ + List[I3FeatureExtractor], I3FeatureExtractor + ], + model_config: Union[ModelConfig, str], + state_dict: str, + model_name: str, + *, + gcd_file: str, + threshold: float = 0.7, + discard_empty_events: bool = False, + ): + """General class for inference on I3Frames (physics). + + Arguments: + pulsemap: the pulsmap that the model is expecting as input + (the one that is being cleaned). + features: the features of the pulsemap that the model is expecting. + pulsemap_extractor: The extractor used to extract the pulsemap. + model_config: The ModelConfig (or path to it) that summarizes the + model used for inference. + state_dict: Path to state_dict containing the learned weights. + model_name: The name used for the model. Will help define the named + entry in the I3Frame. E.g. "dynedge". + gcd_file: path to associated gcd file. + threshold: the threshold for being considered a positive case. + E.g., predictions >= threshold will be considered + to be signal, all else noise. + discard_empty_events: When true, this flag will eliminate events + whose cleaned pulse series are empty. Can be used + to speed up processing especially for noise + simulation, since it will not do any writing or + further calculations. + """ + super().__init__( + pulsemap=pulsemap, + features=features, + pulsemap_extractor=pulsemap_extractor, + model_config=model_config, + state_dict=state_dict, + model_name=model_name, + gcd_file=gcd_file, + ) + self._threshold = threshold + self._predictions_key = f"{pulsemap}_{model_name}_Predictions" + self._total_pulsemap_name = f"{pulsemap}_{model_name}_Pulses" + self._discard_empty_events = discard_empty_events + + def __call__(self, frame: I3Frame) -> bool: + """Add a cleaned pulsemap to frame.""" + # inference + gcd_file = self._gcd_file + data = self._create_data_representation(frame) + if data is None: # If there is no pulses to clean + return False + predictions = self._inference(data)[0] + + if self._discard_empty_events: + if sum(predictions > self._threshold) == 0: + return False + + if len(predictions.shape) == 1: + predictions = predictions.reshape(-1, 1) + + assert predictions.shape[1] == 1 + + del data # memory + # Build Dictionary of predictions + data_dict = {} + + predictions_map = self._construct_prediction_map( + frame=frame, predictions=predictions + ) + + # Adds the raw predictions to dictionary + if self._predictions_key not in frame.keys(): + data_dict[self._predictions_key] = predictions_map + + # Create a pulse map mask, indicating the pulses that are over + # threshold (e.g. 
identified as signal) and therefore should be kept + # Using a lambda function to evaluate which pulses to keep by + # checking the prediction for each pulse + # (Adds the actual pulsemap to dictionary) + if self._total_pulsemap_name not in frame.keys(): + data_dict[ + self._total_pulsemap_name + ] = dataclasses.I3RecoPulseSeriesMapMask( + frame, + self._pulsemap, + lambda om_key, index, pulse: predictions_map[om_key][index] + >= self._threshold, + ) + + # Submit predictions and general pulsemap + frame = self._add_to_frame(frame=frame, data=data_dict) + data = {} + # Adds an additional pulsemap for each DOM type + if isinstance( + self._i3_extractors[0], I3FeatureExtractorIceCubeUpgrade + ): + mDOMMap, DEggMap, IceCubeMap = self._split_pulsemap_in_dom_types( + frame=frame, gcd_file=gcd_file + ) + + if f"{self._total_pulsemap_name}_mDOMs_Only" not in frame.keys(): + data[ + f"{self._total_pulsemap_name}_mDOMs_Only" + ] = dataclasses.I3RecoPulseSeriesMap(mDOMMap) + + if f"{self._total_pulsemap_name}_dEggs_Only" not in frame.keys(): + data[ + f"{self._total_pulsemap_name}_dEggs_Only" + ] = dataclasses.I3RecoPulseSeriesMap(DEggMap) + + if f"{self._total_pulsemap_name}_pDOMs_Only" not in frame.keys(): + data[ + f"{self._total_pulsemap_name}_pDOMs_Only" + ] = dataclasses.I3RecoPulseSeriesMap(IceCubeMap) + + # Submits the additional pulsemaps to the frame + frame = self._add_to_frame(frame=frame, data=data) + + return True + + def _split_pulsemap_in_dom_types( + self, frame: I3Frame, gcd_file: Any + ) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: + """Will split the cleaned pulsemap into multiple pulsemaps. + + Arguments: + frame: I3Frame (physics) + gcd_file: path to associated gcd file + + Returns: + mDOMMap, DeGGMap, IceCubeMap + """ + g = dataio.I3File(gcd_file) + gFrame = g.pop_frame() + while "I3Geometry" not in gFrame.keys(): + gFrame = g.pop_frame() + omGeoMap = gFrame["I3Geometry"].omgeo + + mDOMMap, DEggMap, IceCubeMap = {}, {}, {} + pulses = dataclasses.I3RecoPulseSeriesMap.from_frame( + frame, self._total_pulsemap_name + ) + for P in pulses: + om = omGeoMap[P[0]] + if om.omtype == 130: # "mDOM" + mDOMMap[P[0]] = P[1] + elif om.omtype == 120: # "DEgg" + DEggMap[P[0]] = P[1] + elif om.omtype == 20: # "IceCube / pDOM" + IceCubeMap[P[0]] = P[1] + return mDOMMap, DEggMap, IceCubeMap + + def _construct_prediction_map( + self, frame: I3Frame, predictions: np.ndarray + ) -> I3MapKeyVectorDouble: + """Make a pulsemap from predictions (for all OM types). + + Arguments: + frame: I3Frame (physics) + predictions: predictions from Model. 
+ + Returns: + predictions_map: a pulsemap from predictions + """ + pulsemap = dataclasses.I3RecoPulseSeriesMap.from_frame( + frame, self._pulsemap + ) + + idx = 0 + predictions = predictions.squeeze(1) + predictions_map = dataclasses.I3MapKeyVectorDouble() + for om_key, pulses in pulsemap.items(): + num_pulses = len(pulses) + predictions_map[om_key] = predictions[ + idx : idx + num_pulses + ].tolist() + idx += num_pulses + + # Checks + assert idx == len( + predictions + ), """Not all predictions were mapped to pulses,\n + validation of predictions have failed.""" + + assert ( + pulsemap.keys() == predictions_map.keys() + ), """Input pulse map and predictions map do \n + not contain exactly the same OMs""" + return predictions_map diff --git a/src/graphnet/deployment/icecube/i3deployer.py b/src/graphnet/deployment/icecube/i3deployer.py new file mode 100644 index 000000000..de733b5c4 --- /dev/null +++ b/src/graphnet/deployment/icecube/i3deployer.py @@ -0,0 +1,117 @@ +"""Contains an IceCube-specific implementation of Deployer.""" + +from typing import TYPE_CHECKING, List, Union, Sequence +import os +import numpy as np + +from graphnet.utilities.imports import has_icecube_package +from graphnet.deployment.icecube import I3InferenceModule +from graphnet.data.dataclasses import Settings +from graphnet.deployment import Deployer + +if has_icecube_package() or TYPE_CHECKING: + from icecube import icetray, dataio # pyright: reportMissingImports=false + from I3Tray import I3Tray + + +class I3Deployer(Deployer): + """A generic baseclass for applying `DeploymentModules` to analysis files. + + Modules are applied in the order that they appear in `modules`. + """ + + def __init__( + self, + modules: Union[I3InferenceModule, Sequence[I3InferenceModule]], + gcd_file: str, + n_workers: int = 1, + ) -> None: + """Initialize `Deployer`. + + Will apply `DeploymentModules` to files in the order in which they + appear in `modules`. Each module is run independently. + + Args: + modules: List of `DeploymentModules`. + Order of appearence in the list determines order + of deployment. + gcd_file: path to gcd file. + n_workers: Number of workers. The deployer will divide the number + of input files across workers. Defaults to 1. + """ + super().__init__(modules=modules, n_workers=n_workers) + + # Member variables + self._gcd_file = gcd_file + + def _process_files( + self, + settings: Settings, + ) -> None: + """Will start an IceTray read/write chain with graphnet modules. + + If n_workers > 1, this function is run in parallel n_worker times. Each + worker will loop over an allocated set of i3 files. The new i3 files + will appear as copies of the original i3 files but with reconstructions + added. Original i3 files are left untouched. 
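+
+        A minimal construction sketch (paths are placeholders, and the
+        public entry point that eventually dispatches to this method is
+        assumed to be provided by the `Deployer` base class):
+
+            >>> deployer = I3Deployer(
+            ...     modules=[inference_module],  # I3InferenceModule(s)
+            ...     gcd_file="/path/to/gcd_file.i3.gz",
+            ...     n_workers=2,
+            ... )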
+ """ + for i3_file in settings.i3_files: + tray = I3Tray() + tray.context["I3FileStager"] = dataio.get_stagers() + tray.AddModule( + "I3Reader", + "reader", + FilenameList=[settings.gcd_file, i3_file], + ) + for i3_module in settings.modules: + tray.AddModule(i3_module) + tray.Add( + "I3Writer", + Streams=[ + icetray.I3Frame.DAQ, + icetray.I3Frame.Physics, + icetray.I3Frame.TrayInfo, + icetray.I3Frame.Simulation, + ], + filename=settings.output_folder + "/" + i3_file.split("/")[-1], + ) + tray.Execute() + tray.Finish() + return + + def _prepare_settings( + self, input_files: List[str], output_folder: str + ) -> List[Settings]: + """Will prepare the settings for each worker.""" + try: + os.makedirs(output_folder) + except FileExistsError as e: + self.error( + f"{output_folder} already exists. To avoid overwriting " + "existing files, the process has been stopped." + ) + raise e + if self._n_workers > len(input_files): + self._n_workers = len(input_files) + if self._n_workers > 1: + file_batches = np.array_split(input_files, self._n_workers) + settings: List[Settings] = [] + for i in range(self._n_workers): + settings.append( + Settings( + file_batches[i], + self._gcd_file, + output_folder, + self._modules, + ) + ) + else: + settings = [ + Settings( + input_files, + self._gcd_file, + output_folder, + self._modules, + ) + ] + return settings diff --git a/src/graphnet/deployment/icecube/inference_module.py b/src/graphnet/deployment/icecube/inference_module.py new file mode 100644 index 000000000..9631cc3e0 --- /dev/null +++ b/src/graphnet/deployment/icecube/inference_module.py @@ -0,0 +1,205 @@ +"""IceCube I3InferenceModule. + +Contains functionality for writing model predictions to i3 files. +""" +from typing import List, Union, Optional, TYPE_CHECKING, Dict, Any + +import numpy as np +from torch_geometric.data import Data, Batch + +from graphnet.utilities.config import ModelConfig +from graphnet.deployment import DeploymentModule +from graphnet.data.extractors.icecube import I3FeatureExtractor +from graphnet.utilities.imports import has_icecube_package + +if has_icecube_package() or TYPE_CHECKING: + from icecube.icetray import ( + I3Frame, + ) # pyright: reportMissingImports=false + from icecube.dataclasses import ( + I3Double, + ) # pyright: reportMissingImports=false + + +class I3InferenceModule(DeploymentModule): + """General class for inference on i3 frames.""" + + def __init__( + self, + pulsemap_extractor: Union[ + List[I3FeatureExtractor], I3FeatureExtractor + ], + model_config: Union[ModelConfig, str], + state_dict: str, + model_name: str, + gcd_file: str, + features: Optional[List[str]] = None, + prediction_columns: Optional[Union[List[str], None]] = None, + pulsemap: Optional[str] = None, + ): + """General class for inference on I3Frames (physics). + + Arguments: + pulsemap_extractor: The extractor used to extract the pulsemap. + model_config: The ModelConfig (or path to it) that summarizes the + model used for inference. + state_dict: Path to state_dict containing the learned weights. + model_name: The name used for the model. Will help define the + named entry in the I3Frame. E.g. "dynedge". + gcd_file: path to associated gcd file. + features: the features of the pulsemap that the model is expecting. + prediction_columns: column names for the predictions of the model. + Will help define the named entry in the I3Frame. + E.g. ['energy_reco']. Optional. + pulsemap: the pulsmap that the model is expecting as input. 
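+
+        Example (a hypothetical configuration; the paths, names and the
+        `extractor` instance are placeholders):
+
+            >>> module = I3InferenceModule(
+            ...     pulsemap_extractor=extractor,  # an `I3FeatureExtractor`
+            ...     model_config="/path/to/model_config.yml",
+            ...     state_dict="/path/to/state_dict.pth",
+            ...     model_name="dynedge",
+            ...     gcd_file="/path/to/gcd_file.i3.gz",
+            ...     prediction_columns=["energy_reco"],
+            ...     pulsemap="SplitInIcePulses",
+            ... )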
+ """ + super().__init__( + model_config=model_config, + state_dict=state_dict, + prediction_columns=prediction_columns, + ) + # Checks + assert isinstance(gcd_file, str), "gcd_file must be string" + + # Set Member Variables + if isinstance(pulsemap_extractor, list): + self._i3_extractors = pulsemap_extractor + else: + self._i3_extractors = [pulsemap_extractor] + if features is None: + features = self.model._graph_definition._input_feature_names + self._graph_definition = self.model._graph_definition + self._pulsemap = pulsemap + self._gcd_file = gcd_file + self.model_name = model_name + self._features = features + + # Set GCD file for pulsemap extractor + for i3_extractor in self._i3_extractors: + i3_extractor.set_gcd(i3_file="", gcd_file=self._gcd_file) + + def __call__(self, frame: I3Frame) -> bool: + """Write predictions from model to frame.""" + # inference + data = self._create_data_representation(frame=frame) + predictions = self._apply_model(data=data) + + # Check dimensions of predictions and prediction columns + dim = self._check_dimensions(predictions=predictions) + + # Build Dictionary from predictions + data = self._create_dictionary(dim=dim, predictions=predictions) + + # Submit Dictionary to frame + frame = self._add_to_frame(frame=frame, data=data) + return True + + def _check_dimensions(self, predictions: np.ndarray) -> int: + if len(predictions.shape) > 1: + dim = predictions.shape[1] + else: + dim = len(predictions) + try: + assert dim == len(self.prediction_columns) + except AssertionError as e: + self.error( + f"predictions have shape {dim} but" + f"prediction columns have [{self.prediction_columns}]" + ) + raise e + + assert predictions.shape[0] == 1 + return dim + + def _create_dictionary( + self, dim: int, predictions: np.ndarray + ) -> Dict[str, Any]: + """Transform predictions into a dictionary.""" + data = {} + for i in range(dim): + try: + assert len(predictions[:, i]) == 1 + data[ + self.model_name + "_" + self.prediction_columns[i] + ] = I3Double(float(predictions[:, i][0])) + except IndexError: + data[ + self.model_name + "_" + self.prediction_columns[i] + ] = I3Double(predictions[0]) + return data + + def _apply_model(self, data: Data) -> np.ndarray: + """Apply model to `Data` and case-handling.""" + if data is not None: + predictions = self._inference(data) + if isinstance(predictions, list): + predictions = predictions[0] + self.warning( + f"{self.__class__.__name__} assumes one Task " + f"but got {len(predictions)}. Only the first will" + " be used." + ) + else: + self.warning( + "At least one event has no pulses " + " - padding {self.prediction_columns} with NaN." + ) + predictions = np.repeat( + [np.nan], len(self.prediction_columns) + ).reshape(-1, len(self.prediction_columns)) + return predictions + + def _create_data_representation(self, frame: I3Frame) -> Data: + """Process Physics I3Frame into graph.""" + # Extract features + input_features = self._extract_feature_array_from_frame(frame) + # Prepare graph data + if len(input_features) > 0: + data = self._graph_definition( + input_features=input_features, + input_feature_names=self._features, + ) + return Batch.from_data_list([data]) + else: + return None + + def _extract_feature_array_from_frame(self, frame: I3Frame) -> np.array: + """Apply the I3FeatureExtractors to the I3Frame. 
+ + Arguments: + frame: Physics I3Frame (PFrame) + + Returns: + array with pulsemap + """ + features = None + for i3extractor in self._i3_extractors: + feature_dict = i3extractor(frame) + features_pulsemap = np.array( + [feature_dict[key] for key in self._features] + ).T + if features is None: + features = features_pulsemap + else: + features = np.concatenate( + (features, features_pulsemap), axis=0 + ) + return features + + def _add_to_frame(self, frame: I3Frame, data: Dict[str, Any]) -> I3Frame: + """Add every field in data to I3Frame. + + Arguments: + frame: I3Frame (physics) + data: Dictionary containing content that will be written to frame. + + Returns: + frame: Same I3Frame as input, but with the new entries + """ + assert isinstance( + data, dict + ), f"data must be of type dict. Got {type(data)}" + for key in data.keys(): + if key not in frame: + frame.Put(key, data[key]) + return frame diff --git a/src/graphnet/models/components/embedding.py b/src/graphnet/models/components/embedding.py new file mode 100644 index 000000000..40145ad1a --- /dev/null +++ b/src/graphnet/models/components/embedding.py @@ -0,0 +1,143 @@ +"""Classes for performing embedding of input data.""" +import torch +import torch.nn as nn +from torch.functional import Tensor + +from pytorch_lightning import LightningModule + + +class SinusoidalPosEmb(LightningModule): + """Sinusoidal positional embeddings module. + + This module is from the kaggle competition 2nd place solution (see + arXiv:2310.15674): It performs what is called Fourier encoding or it's used + in the Attention is all you need arXiv:1706.03762. It can be seen as a soft + digitization of the input data + """ + + def __init__( + self, + dim: int = 16, + n_freq: int = 10000, + scaled: bool = False, + ): + """Construct `SinusoidalPosEmb`. + + Args: + dim: Embedding dimension. + n_freq: Number of frequencies. + scaled: Whether or not to scale the output. + """ + super().__init__() + if dim % 2 != 0: + raise ValueError(f"dim has to be even. Got: {dim}") + self.scale = ( + nn.Parameter(torch.ones(1) * dim**-0.5) if scaled else 1.0 + ) + self.dim = dim + self.n_freq = torch.Tensor([n_freq]) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + device = x.device + half_dim = self.dim / 2 + emb = torch.log(self.n_freq.to(device=device)) / half_dim + emb = torch.exp(torch.arange(half_dim, device=device) * (-emb)) + emb = x.unsqueeze(-1) * emb.unsqueeze(0) + emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1) + return emb * self.scale + + +class FourierEncoder(LightningModule): + """Fourier encoder module. + + This module incorporates sinusoidal positional embeddings and auxiliary + embeddings to process input sequences and produce meaningful + representations. + """ + + def __init__( + self, + seq_length: int = 128, + output_dim: int = 384, + scaled: bool = False, + ): + """Construct `FourierEncoder`. + + Args: + seq_length: Dimensionality of the base sinusoidal positional + embeddings. + output_dim: Output dimensionality of the final projection. + scaled: Whether or not to scale the embeddings. 
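+
+        A shape sketch (illustrative only; assumes a padded input of shape
+        `[batch, n_pulses, 6]` with the feature ordering used in `forward`,
+        i.e. xyz, time, charge and an auxiliary flag):
+
+            >>> import torch
+            >>> encoder = FourierEncoder(seq_length=128, output_dim=384)
+            >>> x = torch.rand(2, 100, 6)
+            >>> seq_lengths = torch.tensor([100, 100])
+            >>> encoder(x, seq_lengths).shape
+            torch.Size([2, 100, 384])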
+ """ + super().__init__() + self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled) + self.aux_emb = nn.Embedding(2, seq_length // 2) + self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled) + self.projection = nn.Sequential( + nn.Linear(6 * seq_length, 6 * seq_length), + nn.LayerNorm(6 * seq_length), + nn.GELU(), + nn.Linear(6 * seq_length, output_dim), + ) + + def forward( + self, + x: Tensor, + seq_length: Tensor, + ) -> Tensor: + """Forward pass.""" + length = torch.log10(seq_length.to(dtype=x.dtype)) + x = torch.cat( + [ + self.sin_emb(4096 * x[:, :, :3]).flatten(-2), # pos + self.sin_emb(1024 * x[:, :, 4]), # charge + self.sin_emb(4096 * x[:, :, 3]), # time + self.aux_emb(x[:, :, 5].long()), # auxiliary + self.sin_emb2(length) + .unsqueeze(1) + .expand(-1, max(seq_length), -1), + ], + -1, + ) + x = self.projection(x) + return x + + +class SpacetimeEncoder(LightningModule): + """Spacetime encoder module.""" + + def __init__( + self, + seq_length: int = 32, + ): + """Construct `SpacetimeEncoder`. + + This module calculates space-time interval between each pair of events + and generates sinusoidal positional embeddings to be added to input + sequences. + + Args: + seq_length: Dimensionality of the sinusoidal positional embeddings. + """ + super().__init__() + self.sin_emb = SinusoidalPosEmb(dim=seq_length) + self.projection = nn.Linear(seq_length, seq_length) + + def forward( + self, + x: Tensor, + # Lmax: Optional[int] = None, + ) -> Tensor: + """Forward pass.""" + pos = x[:, :, :3] + time = x[:, :, 3] + spacetime_interval = (pos[:, :, None] - pos[:, None, :]).pow(2).sum( + -1 + ) - ((time[:, :, None] - time[:, None, :]) * (3e4 / 500 * 3e-1)).pow(2) + four_distance = torch.sign(spacetime_interval) * torch.sqrt( + torch.abs(spacetime_interval) + ) + sin_emb = self.sin_emb(1024 * four_distance.clip(-4, 4)) + rel_attn = self.projection(sin_emb) + return rel_attn diff --git a/src/graphnet/models/components/layers.py b/src/graphnet/models/components/layers.py index 53f970286..bd2ed8daf 100644 --- a/src/graphnet/models/components/layers.py +++ b/src/graphnet/models/components/layers.py @@ -1,6 +1,6 @@ """Class(es) implementing layers to be used in `graphnet` models.""" -from typing import Any, Callable, Optional, Sequence, Union, List, Tuple +from typing import Any, Callable, Optional, Sequence, Union, List import torch from torch.functional import Tensor @@ -9,8 +9,10 @@ from torch_geometric.typing import Adj, PairTensor from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.inits import reset +from torch_geometric.data import Data +import torch.nn as nn +from torch.nn.functional import linear from torch.nn.modules import TransformerEncoder, TransformerEncoderLayer -from torch.nn.modules.normalization import LayerNorm from torch_geometric.utils import to_dense_batch from pytorch_lightning import LightningModule @@ -129,13 +131,13 @@ def __init__( """Construct `DynTrans`. Args: - nn: The MLP/torch.Module to be used within the `DynTrans`. layer_sizes: List of layer sizes to be used in `DynTrans`. aggr: Aggregation method to be used with `DynTrans`. features_subset: Subset of features in `Data.x` that should be used when dynamically performing the new graph clustering after the `EdgeConv` operation. Defaults to all features. - n_head: Number of heads to be used in the multiheadattention models. + n_head: Number of heads to be used in the multiheadattention + models. **kwargs: Additional features to be passed to `DynTrans`. 
""" # Check(s) @@ -151,17 +153,17 @@ def __init__( ): if ix == 0: nb_in *= 3 # edgeConv1 - layers.append(torch.nn.Linear(nb_in, nb_out)) - layers.append(torch.nn.LeakyReLU()) + layers.append(nn.Linear(nb_in, nb_out)) + layers.append(nn.LeakyReLU()) d_model = nb_out # Base class constructor - super().__init__(nn=torch.nn.Sequential(*layers), aggr=aggr, **kwargs) + super().__init__(nn=nn.Sequential(*layers), aggr=aggr, **kwargs) # Additional member variables self.features_subset = features_subset - self.norm1 = LayerNorm(d_model, eps=1e-5) # lNorm + self.norm1 = nn.LayerNorm(d_model, eps=1e-5) # lNorm # Transformer layer(s) encoder_layer = TransformerEncoderLayer( @@ -193,3 +195,402 @@ def forward( x = x[mask] return x + + +class DropPath(LightningModule): + """Drop paths (Stochastic Depth) per sample.""" + + def __init__( + self, + drop_prob: float = 0.0, + ): + """Construct `DropPath`. + + Args: + drop_prob: Probability of dropping a path during training. + If 0.0, no paths are dropped. Defaults to None. + """ + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + return x * random_tensor + + def extra_repr(self) -> str: + """Return extra representation of the module.""" + return "p={}".format(self.drop_prob) + + +class Mlp(LightningModule): + """Multi-Layer Perceptron (MLP) module.""" + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + activation: nn.Module = nn.GELU, + dropout_prob: float = 0.0, + ): + """Construct `Mlp`. + + Args: + in_features: Number of input features. + hidden_features: Number of hidden features. Defaults to None. + If None, it is set to the value of `in_features`. + out_features: Number of output features. Defaults to None. + If None, it is set to the value of `in_features`. + activation: Activation layer. Defaults to `nn.GELU`. + dropout_prob: Dropout probability. Defaults to 0.0. + """ + super().__init__() + if in_features <= 0: + raise ValueError( + f"in_features must be greater than 0, got in_features " + f"{in_features} instead" + ) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.input_projection = nn.Linear(in_features, hidden_features) + self.activation = activation() + self.output_projection = nn.Linear(hidden_features, out_features) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.input_projection(x) + x = self.activation(x) + x = self.output_projection(x) + x = self.dropout(x) + return x + + +class Block_rel(LightningModule): + """Implementation of BEiTv2 Block.""" + + def __init__( + self, + input_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + dropout: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + init_values: Optional[float] = None, + activation: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + attn_head_dim: Optional[int] = None, + ): + """Construct 'Block_rel'. + + Args: + input_dim: Dimension of the input tensor. + num_heads: Number of attention heads to use in the `Attention_rel` + layer. 
+ mlp_ratio: Ratio of the hidden size of the feedforward network to + the input size in the `Mlp` layer. + qkv_bias: Whether or not to include bias terms in the query, key, + and value matrices in the `Attention_rel` layer. + qk_scale: Scaling factor for the dot product of the query and key + matrices in the `Attention_rel` layer. + dropout: Dropout probability to use in the `Mlp` layer. + attn_drop: Dropout probability to use in the `Attention_rel` layer. + drop_path: Probability of applying drop path regularization to the + output of the layer. + init_values: Initial value to use for the `gamma_1` and `gamma_2` + parameters if not `None`. + activation: Activation function to use in the `Mlp` layer. + norm_layer: Normalization layer to use. + attn_head_dim: Dimension of the attention head outputs in the + `Attention_rel` layer. + """ + super().__init__() + self.norm1 = norm_layer(input_dim) + self.attn = Attention_rel( + input_dim, + num_heads, + attn_drop=attn_drop, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_head_dim=attn_head_dim, + ) + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + self.norm2 = norm_layer(input_dim) + mlp_hidden_dim = int(input_dim * mlp_ratio) + self.mlp = Mlp( + in_features=input_dim, + hidden_features=mlp_hidden_dim, + activation=activation, + dropout_prob=dropout, + ) + + if init_values is not None: + self.gamma_1 = nn.Parameter( + init_values * torch.ones(input_dim), requires_grad=True + ) + self.gamma_2 = nn.Parameter( + init_values * torch.ones(input_dim), requires_grad=True + ) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward( + self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None, + rel_pos_bias: Optional[Tensor] = None, + kv: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass.""" + if self.gamma_1 is None: + xn = self.norm1(x) + kv = xn if kv is None else self.norm1(kv) + x = x + self.drop_path( + self.attn( + xn, + kv, + kv, + rel_pos_bias=rel_pos_bias, + key_padding_mask=key_padding_mask, + ) + ) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + xn = self.norm1(x) + kv = xn if kv is None else self.norm1(kv) + x = x + self.drop_path( + self.gamma_1 + * self.drop_path( + self.attn( + xn, + kv, + kv, + rel_pos_bias=rel_pos_bias, + key_padding_mask=key_padding_mask, + ) + ) + ) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class Attention_rel(LightningModule): + """Attention mechanism with relative position bias.""" + + def __init__( + self, + input_dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_head_dim: Optional[int] = None, + ): + """Construct 'Attention_rel'. + + Args: + input_dim: Dimension of the input tensor. + num_heads: the number of attention heads to use (default: 8) + qkv_bias: whether to add bias to the query, key, and value + projections. Defaults to False. + qk_scale: a scaling factor that multiplies the dot product of query + and key vectors. Defaults to None. If None, computed as + :math: `head_dim^(-1/2)`. + attn_drop: the dropout probability for the attention weights. + Defaults to 0.0. + proj_drop: the dropout probability for the output of the attention + module. Defaults to 0.0. + attn_head_dim: the feature dimensionality of each attention head. + Defaults to None. If None, computed as `dim // num_heads`. 
+ """ + if input_dim <= 0 or num_heads <= 0: + raise ValueError( + f"dim and num_heads must be greater than 0," + f" got input_dim={input_dim} and num_heads={num_heads} instead" + ) + + super().__init__() + self.num_heads = num_heads + head_dim = attn_head_dim or input_dim // num_heads + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.proj_q = nn.Linear(input_dim, all_head_dim, bias=False) + self.proj_k = nn.Linear(input_dim, all_head_dim, bias=False) + self.proj_v = nn.Linear(input_dim, all_head_dim, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, input_dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + q: Tensor, + k: Tensor, + v: Tensor, + rel_pos_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass.""" + batch_size, event_length, _ = q.shape + + q = linear(input=q, weight=self.proj_q.weight, bias=self.q_bias) + q = q.reshape(batch_size, event_length, self.num_heads, -1).permute( + 0, 2, 1, 3 + ) + k = linear(input=k, weight=self.proj_k.weight, bias=None) + k = k.reshape(batch_size, k.shape[1], self.num_heads, -1).permute( + 0, 2, 1, 3 + ) + v = linear(input=v, weight=self.proj_v.weight, bias=self.v_bias) + v = v.reshape(batch_size, v.shape[1], self.num_heads, -1).permute( + 0, 2, 1, 3 + ) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if rel_pos_bias is not None: + bias = torch.einsum("bhic,bijc->bhij", q, rel_pos_bias) + attn = attn + bias + if key_padding_mask is not None: + assert ( + key_padding_mask.dtype == torch.float32 + or key_padding_mask.dtype == torch.float16 + ), "incorrect mask dtype" + bias = torch.min( + key_padding_mask[:, None, :], key_padding_mask[:, :, None] + ) + bias[ + torch.max( + key_padding_mask[:, None, :], key_padding_mask[:, :, None] + ) + < 0 + ] = 0 + attn = attn + bias.unsqueeze(1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2) + if rel_pos_bias is not None: + x = x + torch.einsum("bhij,bijc->bihc", attn, rel_pos_bias) + x = x.reshape(batch_size, event_length, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(LightningModule): + """Transformer block.""" + + def __init__( + self, + input_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + init_values: Optional[float] = None, + activation: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + ): + """Construct 'Block'. + + Args: + input_dim: Dimension of the input tensor. + num_heads: Number of attention heads to use in the + `MultiheadAttention` layer. + mlp_ratio: Ratio of the hidden size of the feedforward network to + the input size in the `Mlp` layer. + dropout: Dropout probability to use in the `Mlp` layer. + attn_drop: Dropout probability to use in the `MultiheadAttention` + layer. + drop_path: Probability of applying drop path regularization to the + output of the layer. + init_values: Initial value to use for the `gamma_1` and `gamma_2` + parameters if not `None`. + activation: Activation function to use in the `Mlp` layer. + norm_layer: Normalization layer to use. 
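+
+        A minimal usage sketch (tensor sizes are illustrative only):
+
+            >>> import torch
+            >>> block = Block(input_dim=384, num_heads=12)
+            >>> x = torch.rand(2, 100, 384)
+            >>> block(x).shape
+            torch.Size([2, 100, 384])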
+ """ + super().__init__() + self.norm1 = norm_layer(input_dim) + self.attn = nn.MultiheadAttention( + input_dim, num_heads, dropout=attn_drop, batch_first=True + ) + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + self.norm2 = norm_layer(input_dim) + mlp_hidden_dim = int(input_dim * mlp_ratio) + self.mlp = Mlp( + in_features=input_dim, + hidden_features=mlp_hidden_dim, + activation=activation, + dropout_prob=dropout, + ) + + if init_values is not None: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((input_dim)), requires_grad=True + ) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((input_dim)), requires_grad=True + ) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward( + self, + x: Tensor, + attn_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass.""" + if self.gamma_1 is None: + xn = self.norm1(x) + x = x + self.drop_path( + self.attn( + xn, + xn, + xn, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + ) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + xn = self.norm1(x) + x = x + self.drop_path( + self.gamma_1 + * self.attn( + xn, + xn, + xn, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + ) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py index 691b94fc7..c39240d7f 100644 --- a/src/graphnet/models/detector/icecube.py +++ b/src/graphnet/models/detector/icecube.py @@ -28,6 +28,7 @@ def feature_map(self) -> Dict[str, Callable]: "charge": self._charge, "rde": self._rde, "pmt_area": self._pmt_area, + "hlc": self._identity, } return feature_map @@ -85,6 +86,7 @@ def feature_map(self) -> Dict[str, Callable]: "charge": self._identity, "rde": self._rde, "pmt_area": self._pmt_area, + "hlc": self._identity, } return feature_map @@ -131,6 +133,7 @@ def feature_map(self) -> Dict[str, Callable]: "pmt_dir_y": self._identity, "pmt_dir_z": self._identity, "dom_type": self._dom_type, + "hlc": self._identity, } return feature_map diff --git a/src/graphnet/models/gnn/RNN_tito.py b/src/graphnet/models/gnn/RNN_tito.py new file mode 100644 index 000000000..75f3a04fc --- /dev/null +++ b/src/graphnet/models/gnn/RNN_tito.py @@ -0,0 +1,130 @@ +"""RNN_DynEdge model implementation.""" +from typing import List, Optional, Tuple, Union + +import torch +from graphnet.models.gnn.gnn import GNN +from graphnet.models.gnn.dynedge import DynEdge +from graphnet.models.gnn.dynedge_kaggle_tito import DynEdgeTITO +from graphnet.models.rnn.node_rnn import Node_RNN + +from graphnet.utilities.config import save_model_config +from torch_geometric.data import Data + + +class RNN_TITO(GNN): + """The RNN_TITO model class. + + Combines the Node_RNN and DynEdgeTITO models, intended for data with large + amount of DOM activations per event. This model works only with non- + standard dataset specific to the Node_RNN model see Node_RNN for more + details. 
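+
+    A construction sketch (the feature count and time-series column indices
+    below are placeholders that depend on the dataset):
+
+        >>> model = RNN_TITO(
+        ...     nb_inputs=7,
+        ...     time_series_columns=[4, 3],  # charge first, then time
+        ...     rnn_layers=2,
+        ...     rnn_hidden_size=64,
+        ... )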
+ """ + + @save_model_config + def __init__( + self, + nb_inputs: int, + time_series_columns: List[int], + *, + nb_neighbours: int = 8, + rnn_layers: int = 2, + rnn_hidden_size: int = 64, + rnn_dropout: float = 0.5, + features_subset: Optional[List[int]] = None, + dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None, + post_processing_layer_sizes: Optional[List[int]] = None, + readout_layer_sizes: Optional[List[int]] = None, + global_pooling_schemes: List[str] = ["max"], + embedding_dim: Optional[int] = None, + n_head: int = 16, + use_global_features: bool = True, + use_post_processing_layers: bool = True, + ): + """Initialize the RNN_DynEdge model. + + Args: + nb_inputs (int): Number of input features. + time_series_columns (List[int]): The indices of the input data that should be treated as time series data. The first index should be the charge column. + nb_neighbours (int, optional): Number of neighbours to consider. + Defaults to 8. + rnn_layers (int, optional): Number of RNN layers. + Defaults to 1. + rnn_hidden_size (int, optional): Size of the hidden state of the RNN. Also determines the size of the output of the RNN. + Defaults to 64. + rnn_dropout (float, optional): Dropout to use in the RNN. Defaults to 0.5. + features_subset (List[int], optional): The subset of latent + features on each node that are used as metric dimensions when performing the k-nearest neighbours clustering. Defaults to [0,1,2,3] + dyntrans_layer_sizes (List[Tuple[int, ...]], optional): List of tuples representing the sizes of the hidden layers of the DynTrans model. + post_processing_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the post-processing model. + readout_layer_sizes (List[int], optional): List of integers representing the sizes of the hidden layers of the readout model. + global_pooling_schemes (Union[str, List[str]], optional): Pooling schemes to use. Defaults to None. + embedding_dim (int, optional): Embedding dimension of the RNN. Defaults to None ie. no embedding. + n_head (int, optional): Number of heads to use in the DynTrans model. Defaults to 16. + use_global_features (bool, optional): Whether to use global features after pooling. Defaults to True. + use_post_processing_layers (bool, optional): Whether to use post-processing layers after the DynTrans layers. Defaults to True. 
+ """ + self._nb_neighbours = nb_neighbours + self._nb_inputs = nb_inputs + self._rnn_layers = rnn_layers + self._rnn_hidden_size = rnn_hidden_size + self._rnn_dropout = rnn_dropout + self._embedding_dim = embedding_dim + self._n_head = n_head + self._use_global_features = use_global_features + self._use_post_processing_layers = use_post_processing_layers + + self._features_subset = features_subset + if dyntrans_layer_sizes is None: + dyntrans_layer_sizes = [ + (256, 256), + (256, 256), + (256, 256), + (256, 256), + ] + else: + dyntrans_layer_sizes = [ + tuple(layer_sizes) for layer_sizes in dyntrans_layer_sizes + ] + + self._dyntrans_layer_sizes = dyntrans_layer_sizes + self._post_processing_layer_sizes = post_processing_layer_sizes + self._global_pooling_schemes = global_pooling_schemes + if readout_layer_sizes is None: + readout_layer_sizes = [ + 256, + 128, + ] + self._readout_layer_sizes = readout_layer_sizes + + super().__init__(nb_inputs, self._readout_layer_sizes[-1]) + + self._rnn = Node_RNN( + nb_inputs=2, + hidden_size=self._rnn_hidden_size, + num_layers=self._rnn_layers, + time_series_columns=time_series_columns, + nb_neighbours=self._nb_neighbours, + features_subset=self._features_subset, + dropout=self._rnn_dropout, + embedding_dim=self._embedding_dim, + ) + + self._dynedge_tito = DynEdgeTITO( + nb_inputs=self._rnn_hidden_size + 5, + dyntrans_layer_sizes=self._dyntrans_layer_sizes, + features_subset=self._features_subset, + global_pooling_schemes=self._global_pooling_schemes, + use_global_features=self._use_global_features, + use_post_processing_layers=self._use_post_processing_layers, + post_processing_layer_sizes=self._post_processing_layer_sizes, + readout_layer_sizes=self._readout_layer_sizes, + n_head=self._n_head, + nb_neighbours=self._nb_neighbours, + ) + + def forward(self, data: Data) -> torch.Tensor: + """Apply learnable forward pass of the RNN and tito model.""" + data = self._rnn(data) + readout = self._dynedge_tito(data) + + return readout diff --git a/src/graphnet/models/gnn/__init__.py b/src/graphnet/models/gnn/__init__.py index 2abe3d358..9b4de8238 100644 --- a/src/graphnet/models/gnn/__init__.py +++ b/src/graphnet/models/gnn/__init__.py @@ -4,3 +4,5 @@ from .dynedge import DynEdge from .dynedge_jinst import DynEdgeJINST from .dynedge_kaggle_tito import DynEdgeTITO +from .RNN_tito import RNN_TITO +from .icemix import DeepIce diff --git a/src/graphnet/models/gnn/dynedge.py b/src/graphnet/models/gnn/dynedge.py index 9ea93f9ce..e4d0160e4 100644 --- a/src/graphnet/models/gnn/dynedge.py +++ b/src/graphnet/models/gnn/dynedge.py @@ -1,5 +1,5 @@ """Implementation of the DynEdge GNN model architecture.""" -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Callable, Tuple, Union import torch from torch import Tensor, LongTensor @@ -32,6 +32,9 @@ def __init__( readout_layer_sizes: Optional[List[int]] = None, global_pooling_schemes: Optional[Union[str, List[str]]] = None, add_global_variables_after_pooling: bool = False, + activation_layer: Callable = None, + add_norm_layer: bool = False, + skip_readout: bool = False, ): """Construct `DynEdge`. @@ -65,6 +68,11 @@ def __init__( after global pooling. The alternative is to added (distribute) them to the individual nodes before any convolutional operations. + activation_layer: The activation function to use in the model. + add_norm_layer: Whether to add a normalization layer after each + linear layer. + skip_readout: Whether to skip the readout layer(s). 
If `True`, the + output of the last post-processing layer is returned directly. """ # Latent feature subset for computing nearest neighbours in DynEdge. if features_subset is None: @@ -149,15 +157,20 @@ def __init__( add_global_variables_after_pooling ) + if activation_layer is None: + activation_layer = torch.nn.ReLU() + # Base class constructor super().__init__(nb_inputs, self._readout_layer_sizes[-1]) # Remaining member variables() - self._activation = torch.nn.LeakyReLU() + self._activation = activation_layer self._nb_inputs = nb_inputs self._nb_global_variables = 5 + nb_inputs self._nb_neighbours = nb_neighbours self._features_subset = features_subset + self._add_norm_layer = add_norm_layer + self._skip_readout = skip_readout self._construct_layers() @@ -179,6 +192,8 @@ def _construct_layers(self) -> None: if ix == 0: nb_in *= 2 layers.append(torch.nn.Linear(nb_in, nb_out)) + if self._add_norm_layer: + layers.append(torch.nn.LayerNorm(nb_out)) layers.append(self._activation) conv_layer = DynEdgeConv( @@ -203,6 +218,8 @@ def _construct_layers(self) -> None: ) for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): post_processing_layers.append(torch.nn.Linear(nb_in, nb_out)) + if self._add_norm_layer: + post_processing_layers.append(torch.nn.LayerNorm(nb_out)) post_processing_layers.append(self._activation) self._post_processing = torch.nn.Sequential(*post_processing_layers) @@ -307,19 +324,20 @@ def forward(self, data: Data) -> Tensor: # Post-processing x = self._post_processing(x) - # (Optional) Global pooling - if self._global_pooling_schemes: - x = self._global_pooling(x, batch=batch) - if self._add_global_variables_after_pooling: - x = torch.cat( - [ - x, - global_variables, - ], - dim=1, - ) - - # Read-out - x = self._readout(x) + if not self._skip_readout: + # (Optional) Global pooling + if self._global_pooling_schemes: + x = self._global_pooling(x, batch=batch) + if self._add_global_variables_after_pooling: + x = torch.cat( + [ + x, + global_variables, + ], + dim=1, + ) + + # Read-out + x = self._readout(x) return x diff --git a/src/graphnet/models/gnn/dynedge_kaggle_tito.py b/src/graphnet/models/gnn/dynedge_kaggle_tito.py index 78b5aebe5..12490e808 100644 --- a/src/graphnet/models/gnn/dynedge_kaggle_tito.py +++ b/src/graphnet/models/gnn/dynedge_kaggle_tito.py @@ -39,6 +39,10 @@ def __init__( global_pooling_schemes: List[str] = ["max"], use_global_features: bool = True, use_post_processing_layers: bool = True, + post_processing_layer_sizes: List[int] = None, + readout_layer_sizes: Optional[List[int]] = None, + n_head: int = 8, + nb_neighbours: int = 8, ): """Construct `DynEdgeTITO`. @@ -53,8 +57,12 @@ def __init__( global_pooling_schemes: The list global pooling schemes to use. Options are: "min", "max", "mean", and "sum". use_global_features: Whether to use global features after pooling. - use_post_processing_layers: Whether to use post-processing layers - after the `DynTrans` layers. + use_post_processing_layers: Whether to use post-processing layers after the `DynTrans` layers. + post_processing_layer_sizes: The layer sizes used in the post-processing layers. Defaults to [336, 256]. + readout_layer_sizes: The layer sizes used in the readout layers. Defaults to [256, 128]. + n_head: The number of heads to use in the `DynTrans` layer. + nb_neighbours: The number of neighbours to use in the `DynTrans` + layer. 
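+
+        Example of the added configuration options (the values shown are
+        simply the documented defaults; `nb_inputs` is a placeholder
+        feature count):
+
+            >>> model = DynEdgeTITO(
+            ...     nb_inputs=7,
+            ...     global_pooling_schemes=["max"],
+            ...     post_processing_layer_sizes=[336, 256],
+            ...     readout_layer_sizes=[256, 128],
+            ...     n_head=8,
+            ...     nb_neighbours=8,
+            ... )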
""" # DynTrans layer sizes if dyntrans_layer_sizes is None: @@ -88,18 +96,20 @@ def __init__( self._dyntrans_layer_sizes = dyntrans_layer_sizes # Post-processing layer sizes - post_processing_layer_sizes = [ - 336, - 256, - ] + if post_processing_layer_sizes is None: + post_processing_layer_sizes = [ + 336, + 256, + ] self._post_processing_layer_sizes = post_processing_layer_sizes # Read-out layer sizes - readout_layer_sizes = [ - 256, - 128, - ] + if readout_layer_sizes is None: + readout_layer_sizes = [ + 256, + 128, + ] self._readout_layer_sizes = readout_layer_sizes @@ -129,10 +139,11 @@ def __init__( self._activation = torch.nn.LeakyReLU() self._nb_inputs = nb_inputs self._nb_global_variables = 5 + nb_inputs - self._nb_neighbours = 8 + self._nb_neighbours = nb_neighbours self._features_subset = features_subset or [0, 1, 2, 3] self._use_global_features = use_global_features self._use_post_processing_layers = use_post_processing_layers + self._n_head = n_head self._construct_layers() def _construct_layers(self) -> None: @@ -147,7 +158,7 @@ def _construct_layers(self) -> None: [nb_latent_features] + list(sizes), aggr="max", features_subset=self._features_subset, - n_head=8, + n_head=self._n_head, ) self._conv_layers.append(conv_layer) nb_latent_features = sizes[-1] diff --git a/src/graphnet/models/gnn/icemix.py b/src/graphnet/models/gnn/icemix.py new file mode 100644 index 000000000..8ecf0bf62 --- /dev/null +++ b/src/graphnet/models/gnn/icemix.py @@ -0,0 +1,159 @@ +"""Implementation of IceMix architecture used in. + + IceCube - Neutrinos in Deep Ice +Reconstruct the direction of neutrinos from the Universe to the South Pole + +Kaggle competition. + +Solution by DrHB: https://github.com/DrHB/icecube-2nd-place +""" +import torch +import torch.nn as nn +from typing import Set, Dict, Any, List + +from graphnet.models.components.layers import ( + Block_rel, + Block, +) +from graphnet.models.components.embedding import ( + FourierEncoder, + SpacetimeEncoder, +) +from graphnet.models.gnn.dynedge import DynEdge +from graphnet.models.gnn.gnn import GNN +from graphnet.models.utils import array_to_sequence + +from torch_geometric.utils import to_dense_batch +from torch_geometric.data import Data +from torch import Tensor + + +class DeepIce(GNN): + """DeepIce model.""" + + def __init__( + self, + hidden_dim: int = 384, + seq_length: int = 128, + depth: int = 12, + head_size: int = 32, + depth_rel: int = 4, + n_rel: int = 1, + scaled_emb: bool = False, + include_dynedge: bool = False, + dynedge_args: Dict[str, Any] = None, + ): + """Construct `DeepIce`. + + Args: + hidden_dim: The latent feature dimension. + seq_length: The base feature dimension. + depth: The depth of the transformer. + head_size: The size of the attention heads. + depth_rel: The depth of the relative transformer. + n_rel: The number of relative transformer layers to use. + scaled_emb: Whether to scale the sinusoidal positional embeddings. + include_dynedge: If True, pulse-level predictions from `DynEdge` + will be added as features to the model. + dynedge_args: Initialization arguments for DynEdge. If not + provided, DynEdge will be initialized with the original Kaggle + Competition settings. If `include_dynedge` is False, this + argument have no impact. 
+ """ + super().__init__(seq_length, hidden_dim) + fourier_out_dim = hidden_dim // 2 if include_dynedge else hidden_dim + self.fourier_ext = FourierEncoder( + seq_length, fourier_out_dim, scaled=scaled_emb + ) + self.rel_pos = SpacetimeEncoder(head_size) + self.sandwich = nn.ModuleList( + [ + Block_rel( + input_dim=hidden_dim, num_heads=hidden_dim // head_size + ) + for _ in range(depth_rel) + ] + ) + self.cls_token = nn.Linear(hidden_dim, 1, bias=False) + self.blocks = nn.ModuleList( + [ + Block( + input_dim=hidden_dim, + num_heads=hidden_dim // head_size, + mlp_ratio=4, + drop_path=0.0 * (i / (depth - 1)), + init_values=1, + ) + for i in range(depth) + ] + ) + self.n_rel = n_rel + + if include_dynedge and dynedge_args is None: + self.warning_once("Running with default DynEdge settings") + self.dyn_edge = DynEdge( + nb_inputs=9, + nb_neighbours=9, + post_processing_layer_sizes=[336, hidden_dim // 2], + dynedge_layer_sizes=[ + (128, 256), + (336, 256), + (336, 256), + (336, 256), + ], + global_pooling_schemes=None, + activation_layer=nn.GELU(), + add_norm_layer=True, + skip_readout=True, + ) + elif include_dynedge and not (dynedge_args is None): + self.dyn_edge = DynEdge(**dynedge_args) + + self.include_dynedge = include_dynedge + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + """cls_tocken should not be subject to weight decay during training.""" + return {"cls_token"} + + def forward(self, data: Data) -> Tensor: + """Apply learnable forward pass.""" + x0, mask, seq_length = array_to_sequence( + data.x, data.batch, padding_value=0 + ) + x = self.fourier_ext(x0, seq_length) + rel_pos_bias = self.rel_pos(x0) + batch_size = mask.shape[0] + if self.include_dynedge: + graph = self.dyn_edge(data) + graph, _ = to_dense_batch(graph, data.batch) + x = torch.cat([x, graph], 2) + + attn_mask = torch.zeros(mask.shape, device=mask.device) + attn_mask[~mask] = -torch.inf + + for i, blk in enumerate(self.sandwich): + x = blk(x, attn_mask, rel_pos_bias) + if i + 1 == self.n_rel: + rel_pos_bias = None + + mask = torch.cat( + [ + torch.ones( + batch_size, 1, dtype=mask.dtype, device=mask.device + ), + mask, + ], + 1, + ) + attn_mask = torch.zeros(mask.shape, device=mask.device) + attn_mask[~mask] = -torch.inf + cls_token = self.cls_token.weight.unsqueeze(0).expand( + batch_size, -1, -1 + ) + x = torch.cat([cls_token, x], 1) + + for blk in self.blocks: + x = blk(x, None, attn_mask) + + return x[:, 0] diff --git a/src/graphnet/models/graphs/edges/minkowski.py b/src/graphnet/models/graphs/edges/minkowski.py index 5d1134ec5..2526de1cb 100644 --- a/src/graphnet/models/graphs/edges/minkowski.py +++ b/src/graphnet/models/graphs/edges/minkowski.py @@ -69,12 +69,13 @@ def _construct_edges(self, graph: Data) -> Data: row = [] col = [] for batch in range(x.shape[0]): + x_masked = x[batch][mask[batch]] distance_mat = compute_minkowski_distance_mat( - x_masked := x[batch][mask[batch]], - x_masked, - self.c, - self.space_coords, - self.time_coord, + x=x_masked, + y=x_masked, + c=self.c, + space_coords=self.space_coords, + time_coord=self.time_coord, ) num_points = x_masked.shape[0] num_edges = min(self.nb_nearest_neighbours, num_points) diff --git a/src/graphnet/models/graphs/nodes/__init__.py b/src/graphnet/models/graphs/nodes/__init__.py index 0119d2b98..64bcf70ba 100644 --- a/src/graphnet/models/graphs/nodes/__init__.py +++ b/src/graphnet/models/graphs/nodes/__init__.py @@ -5,4 +5,10 @@ and their features. 
""" -from .nodes import NodeDefinition, NodesAsPulses, PercentileClusters +from .nodes import ( + NodeDefinition, + NodesAsPulses, + PercentileClusters, + NodeAsDOMTimeSeries, + IceMixNodes, +) diff --git a/src/graphnet/models/graphs/nodes/nodes.py b/src/graphnet/models/graphs/nodes/nodes.py index fa0400b97..5d47a8487 100644 --- a/src/graphnet/models/graphs/nodes/nodes.py +++ b/src/graphnet/models/graphs/nodes/nodes.py @@ -1,6 +1,6 @@ """Class(es) for building/connecting graphs.""" -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Union from abc import abstractmethod import torch @@ -11,9 +11,13 @@ from graphnet.models.graphs.utils import ( cluster_summarize_with_percentiles, identify_indices, + lex_sort, + ice_transparency, ) from copy import deepcopy +import numpy as np + class NodeDefinition(Model): # pylint: disable=too-few-public-methods """Base class for graph building.""" @@ -211,3 +215,212 @@ def _construct_nodes(self, x: torch.Tensor) -> Data: raise AttributeError return Data(x=torch.tensor(array)) + + +class NodeAsDOMTimeSeries(NodeDefinition): + """Represent each node as a DOM with time and charge time series data.""" + + def __init__( + self, + keys: List[str] = [ + "dom_x", + "dom_y", + "dom_z", + "dom_time", + "charge", + ], + id_columns: List[str] = ["dom_x", "dom_y", "dom_z"], + time_column: str = "dom_time", + charge_column: str = "charge", + max_activations: Optional[int] = None, + ) -> None: + """Construct `NodeAsDOMTimeSeries`. + + Args: + keys: Names of features in the data (in order). + id_columns: List of columns that uniquely identify a DOM. + time_column: Name of time column. + charge_column: Name of charge column. + max_activations: Maximum number of activations to include in the time series. + """ + self._keys = keys + super().__init__(input_feature_names=self._keys) + self._id_columns = [self._keys.index(key) for key in id_columns] + self._time_index = self._keys.index(time_column) + try: + self._charge_index: Optional[int] = self._keys.index(charge_column) + except ValueError: + self.warning( + "Charge column with name {} not found. 
Running without.".format( + charge_column + ) + ) + + self._charge_index = None + + self._max_activations = max_activations + + def _define_output_feature_names( + self, input_feature_names: List[str] + ) -> List[str]: + return input_feature_names + ["new_node_col"] + + def _construct_nodes(self, x: torch.Tensor) -> Data: + """Construct nodes from raw node features ´x´.""" + # Cast to Numpy + x = x.numpy() + # if there is no charge column add a dummy column of zeros with the same shape as the time column + if self._charge_index is None: + charge_index: int = len(self._keys) + x = np.insert(x, charge_index, np.zeros(x.shape[0]), axis=1) + else: + charge_index = self._charge_index + + # Sort by time + x = x[x[:, self._time_index].argsort()] + # Undo log10 scaling so we can sum charges + x[:, charge_index] = np.power(10, x[:, charge_index]) + # Shift time to start at 0 + x[:, self._time_index] -= np.min(x[:, self._time_index]) + # Group pulses on the same DOM + x = lex_sort(x, self._id_columns) + + unique_sensors, counts = np.unique( + x[:, self._id_columns], axis=0, return_counts=True + ) + + sort_this = np.concatenate( + [unique_sensors, counts.reshape(-1, 1)], axis=1 + ) + sort_this = lex_sort(x=sort_this, cluster_columns=self._id_columns) + unique_sensors = sort_this[:, 0 : unique_sensors.shape[1]] + counts = sort_this[:, unique_sensors.shape[1] :].flatten().astype(int) + + new_node_col = np.zeros(x.shape[0]) + new_node_col[counts.cumsum()[:-1]] = 1 + new_node_col[0] = 1 + x = np.column_stack([x, new_node_col]) + + return Data(x=torch.tensor(x)) + + +class IceMixNodes(NodeDefinition): + """Calculate ice properties and perform random sampling. + + Ice properties are calculated based on the z-coordinate of the pulse. For + each event, a random sampling is performed to keep the number of pulses + below a maximum number of pulses if n_pulses is over the limit. + """ + + def __init__( + self, + input_feature_names: Optional[List[str]] = None, + max_pulses: int = 768, + z_name: str = "dom_z", + hlc_name: str = "hlc", + ) -> None: + """Construct `IceMixNodes`. + + Args: + input_feature_names: Column names for input features. Minimum + required features are z coordinate and hlc column names. + max_pulses: Maximum number of pulses to keep in the event. + z_name: Name of the z-coordinate column. + hlc_name: Name of the `Hard Local Coincidence Check` column. 
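+
+        Example (assumes the ice-transparency table shipped under
+        `DATA_DIR/ice_properties` is available, since it is loaded at
+        construction time; the feature list is the module default):
+
+            >>> node_definition = IceMixNodes(
+            ...     input_feature_names=[
+            ...         "dom_x", "dom_y", "dom_z", "dom_time",
+            ...         "charge", "hlc", "rde",
+            ...     ],
+            ...     max_pulses=768,
+            ... )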
+ """ + super().__init__(input_feature_names=input_feature_names) + + if input_feature_names is None: + input_feature_names = [ + "dom_x", + "dom_y", + "dom_z", + "dom_time", + "charge", + "hlc", + "rde", + ] + + if z_name not in input_feature_names: + raise ValueError( + f"z name {z_name} not found in " + f"input_feature_names {input_feature_names}" + ) + if hlc_name not in input_feature_names: + raise ValueError( + f"hlc name {hlc_name} not found in " + f"input_feature_names {input_feature_names}" + ) + + self.all_features = input_feature_names + [ + "scatt_lenght", + "abs_lenght", + ] + + self.feature_indexes = { + feat: self.all_features.index(feat) for feat in input_feature_names + } + + self.f_scattering, self.f_absoprtion = ice_transparency() + + self.input_feature_names = input_feature_names + self.n_features = len(self.all_features) + self.max_length = max_pulses + self.z_name = z_name + self.hlc_name = hlc_name + + def _define_output_feature_names( + self, input_feature_names: List[str] + ) -> List[str]: + return self.all_features + + def _add_ice_properties( + self, graph: torch.Tensor, x: torch.Tensor, ids: List[int] + ) -> torch.Tensor: + + graph[: len(ids), -2] = torch.tensor( + self.f_scattering(x[ids, self.feature_indexes[self.z_name]]) + ) + graph[: len(ids), -1] = torch.tensor( + self.f_absoprtion(x[ids, self.feature_indexes[self.z_name]]) + ) + return graph + + def _pulse_sampler( + self, x: torch.Tensor, event_length: int + ) -> torch.Tensor: + + if event_length < self.max_length: + ids = torch.arange(event_length) + else: + ids = torch.randperm(event_length) + auxiliary_n = torch.nonzero( + x[:, self.feature_indexes[self.hlc_name]] == 0 + ).squeeze(1) + auxiliary_p = torch.nonzero( + x[:, self.feature_indexes[self.hlc_name]] == 1 + ).squeeze(1) + ids_n = ids[auxiliary_n][: min(self.max_length, len(auxiliary_n))] + ids_p = ids[auxiliary_p][ + : min(self.max_length - len(ids_n), len(auxiliary_p)) + ] + ids = torch.cat([ids_n, ids_p]).sort().values + return ids + + def _construct_nodes(self, x: torch.Tensor) -> Tuple[Data, List[str]]: + + event_length = x.shape[0] + x[:, self.feature_indexes[self.hlc_name]] = torch.logical_not( + x[:, self.feature_indexes[self.hlc_name]] + ) # hlc in kaggle was flipped + ids = self._pulse_sampler(x, event_length) + event_length = min(self.max_length, event_length) + + graph = torch.zeros([event_length, self.n_features]) + for idx, feature in enumerate( + self.all_features[: self.n_features - 2] + ): + graph[:event_length, idx] = x[ids, self.feature_indexes[feature]] + + graph = self._add_ice_properties(graph, x, ids) # ice properties + return Data(x=graph) diff --git a/src/graphnet/models/graphs/utils.py b/src/graphnet/models/graphs/utils.py index ccd861783..63868d7a6 100644 --- a/src/graphnet/models/graphs/utils.py +++ b/src/graphnet/models/graphs/utils.py @@ -1,7 +1,12 @@ """Utility functions for construction of graphs.""" from typing import List, Tuple +import os import numpy as np +import pandas as pd +from scipy.interpolate import interp1d +from sklearn.preprocessing import RobustScaler +from graphnet.constants import DATA_DIR def lex_sort(x: np.array, cluster_columns: List[int]) -> np.ndarray: @@ -158,3 +163,31 @@ def cluster_summarize_with_percentiles( ) return array + + +def ice_transparency() -> Tuple[interp1d, interp1d]: + """Return interpolation functions for optical properties of IceCube. + + NOTE: The resulting interpolation functions assumes that the + Z-coordinate of pulse are scaled as `z = z/500`. 
+ Any deviation from this scaling method results in inaccurate results. + + Returns: + f_scattering: Function that takes a normalized depth and returns the + corresponding normalized scattering length. + f_absorption: Function that takes a normalized depth and returns the + corresponding normalized absorption length. + """ + # Data from page 31 of https://arxiv.org/pdf/1301.5361.pdf + df = pd.read_parquet( + os.path.join(DATA_DIR, "ice_properties/ice_transparency.parquet"), + ) + df["z"] = df["depth"] - 1950 + df["z_norm"] = df["z"] / 500 + df[ + ["scattering_len_norm", "absorption_len_norm"] + ] = RobustScaler().fit_transform(df[["scattering_len", "absorption_len"]]) + + f_scattering = interp1d(df["z_norm"], df["scattering_len_norm"]) + f_absorption = interp1d(df["z_norm"], df["absorption_len_norm"]) + return f_scattering, f_absorption diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py index b4c9ca557..7c51c7952 100644 --- a/src/graphnet/models/model.py +++ b/src/graphnet/models/model.py @@ -23,6 +23,8 @@ class Model( ): """Base class for all components in graphnet.""" + verbose_print = True + @staticmethod def _get_batch_size(data: List[Data]) -> int: return sum([torch.numel(torch.unique(d.batch)) for d in data]) @@ -56,7 +58,7 @@ def save_state_dict(self, path: str) -> None: torch.save(state_dict, path) self.info(f"Model state_dict saved to {path}") - def load_state_dict( + def load_state_dict( # type: ignore[override] self, path: Union[str, Dict], **kargs: Optional[Any] ) -> "Model": # pylint: disable=arguments-differ """Load model `state_dict` from `path`.""" @@ -104,3 +106,41 @@ def from_config( # type: ignore[override] ), f"Argument `source` of type ({type(source)}) is not a `ModelConfig" return source._construct_model(trust, load_modules) + + def set_verbose_print_recursively(self, verbose_print: bool) -> None: + """Set verbose_print recursively for all Model modules.""" + for module in self.modules(): + if isinstance(module, Model): + module.verbose_print = verbose_print + self.verbose_print = verbose_print + + def extra_repr(self) -> str: + """Provide a more detailed description of the object print. + + Returns: + str: A string representation containing detailed information + about the object. 
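+
+        Example (illustrative; `model` is any GraphNeT `Model` instance):
+
+            >>> model.set_verbose_print_recursively(False)
+            >>> print(model)  # config-based details are omitted
+            >>> model.set_verbose_print_recursively(True)
+            >>> print(model)  # config-based details are included again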
+ """ + return self._extra_repr() if self.verbose_print else "" + + def _extra_repr(self) -> str: + """Detailed information about the object.""" + return f"""{self.__class__.__name__}(\n{self.extra_repr_recursive( + self._config.__dict__)})""" + + def extra_repr_recursive(self, dictionary: dict, indent: int = 4) -> str: + """Recursively format a dictionary for extra_repr.""" + result = "{\n" + for key, value in dictionary.items(): + if key == "class_name": + continue + result += " " * indent + f"'{key}': " + if isinstance(value, dict): + result += self.extra_repr_recursive(value, indent + 4) + elif isinstance(value, Model): + result += value.__repr__() + else: + result += repr(value) + result += ",\n" + result += " " * (indent - 4) + "}" + return result diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml index d9ca5b4f1..b3adc498b 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/SplitInIcePulses_cleaner/SplitInIcePulses_cleaner_config.yml @@ -19,7 +19,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - input_feature_names: null + input_feature_names: ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area', 'string', 'pmt_number', 'dom_number', 'pmt_dir_x', 'pmt_dir_y', 'pmt_dir_z', 'dom_type'] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml index 70c82f90d..fdb414b37 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_direction/neutrino_direction_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - input_feature_names: null + input_feature_names: ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area', 'string', 'pmt_number', 'dom_number', 'pmt_dir_x', 'pmt_dir_y', 'pmt_dir_z', 'dom_type'] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml index b75cbf87d..9379e16fd 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_vs_muon_classifier/neutrino_vs_muon_classifier_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - input_feature_names: null + input_feature_names: ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area', 'string', 'pmt_number', 'dom_number', 'pmt_dir_x', 'pmt_dir_y', 'pmt_dir_z', 'dom_type'] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml 
b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml index 29f8e3e63..937152aa2 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/neutrino_zenith/neutrino_zenith_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - input_feature_names: null + input_feature_names: ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area', 'string', 'pmt_number', 'dom_number', 'pmt_dir_x', 'pmt_dir_y', 'pmt_dir_z', 'dom_type'] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml index 916c7db62..8091df6e3 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/total_neutrino_energy/total_neutrino_energy_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - input_feature_names: null + input_feature_names: ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area', 'string', 'pmt_number', 'dom_number', 'pmt_dir_x', 'pmt_dir_y', 'pmt_dir_z', 'dom_type'] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml index d233ed7e9..3a5c52631 100644 --- a/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml +++ b/src/graphnet/models/pretrained/icecube/upgrade/QUESO/track_vs_cascade_classifier/track_vs_cascade_classifier_config.yml @@ -25,7 +25,7 @@ arguments: ModelConfig: arguments: {} class_name: NodesAsPulses - input_feature_names: null + input_feature_names: ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area', 'string', 'pmt_number', 'dom_number', 'pmt_dir_x', 'pmt_dir_y', 'pmt_dir_z', 'dom_type'] class_name: KNNGraph optimizer_class: '!class torch.optim.adam Adam' optimizer_kwargs: null diff --git a/src/graphnet/models/rnn/__init__.py b/src/graphnet/models/rnn/__init__.py new file mode 100644 index 000000000..21d29d7e7 --- /dev/null +++ b/src/graphnet/models/rnn/__init__.py @@ -0,0 +1,3 @@ +"""Recurrent neural network specific modules.""" + +from .node_rnn import Node_RNN diff --git a/src/graphnet/models/rnn/node_rnn.py b/src/graphnet/models/rnn/node_rnn.py new file mode 100644 index 000000000..45f7d643d --- /dev/null +++ b/src/graphnet/models/rnn/node_rnn.py @@ -0,0 +1,136 @@ +"""Implementation of the NodeTimeRNN model. + +(cannot be used as a standalone model) +""" +import torch + +from graphnet.models.gnn.gnn import GNN +from graphnet.utilities.config import save_model_config +from torch_geometric.data import Data +from torch_geometric.nn.pool import knn_graph +from typing import List, Optional + + +from graphnet.models.components.embedding import SinusoidalPosEmb + + +class Node_RNN(GNN): + """Implementation of the Node RNN model architecture. 
+
+    The model takes as input the typical DOM data format and transforms it
+    into a time series of DOM activations per DOM, before applying an RNN
+    layer and producing an RNN output for each DOM. In its current state,
+    this model is not intended to be used as a standalone model. Furthermore,
+    it needs to be used with a time-series dataset object, where the last
+    column in x is a special column that is used to separate the activations
+    into time series per DOM per batch.
+    """
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_inputs: int,
+        hidden_size: int,
+        num_layers: int,
+        time_series_columns: List[int],
+        nb_neighbours: int = 8,
+        features_subset: Optional[List[int]] = None,
+        dropout: float = 0.5,
+        embedding_dim: int = 0,
+    ) -> None:
+        """Construct `Node_RNN`.
+
+        Args:
+            nb_inputs: Number of features in the input data.
+            hidden_size: Number of features for the RNN output and hidden
+                layers.
+            num_layers: Number of layers in the RNN.
+            time_series_columns: The indices of the input data that should be
+                treated as time series data. The first index should be the
+                charge column.
+            nb_neighbours: Number of neighbours to use when reconstructing
+                the graph representation. Defaults to 8.
+            features_subset: The subset of latent features on each node that
+                are used as metric dimensions when performing the k-nearest
+                neighbours clustering. Defaults to [0, 1, 2, 3].
+            dropout: Dropout fraction to use in the RNN. Defaults to 0.5.
+            embedding_dim: Embedding dimension of the RNN. Defaults to no
+                embedding.
+        """
+        self._hidden_size = hidden_size
+        self._num_layers = num_layers
+        self._time_series_columns = time_series_columns
+        self._nb_neighbors = nb_neighbours
+        self._features_subset = features_subset
+        self._embedding_dim = embedding_dim
+        self._nb_inputs = nb_inputs
+
+        super().__init__(nb_inputs, hidden_size + 5)
+
+        if self._embedding_dim != 0:
+            self._nb_inputs = self._embedding_dim * nb_inputs
+
+        self._rnn = torch.nn.GRU(
+            num_layers=self._num_layers,
+            input_size=self._nb_inputs,
+            hidden_size=self._hidden_size,
+            batch_first=True,
+            dropout=dropout,
+        )
+        self._emb = SinusoidalPosEmb(dim=self._embedding_dim)
+
+    def clean_up_data_object(self, data: Data) -> Data:
+        """Update the feature names of the data object.
+
+        Args:
+            data: The input data object.
+        """
+        # Remove the appended new_node column from the original features.
+        old_features = data.features[0][:-1]
+        new_features = old_features + [
+            "rnn_out_" + str(i) for i in range(self._hidden_size)
+        ]
+        data.features = [new_features] * len(data.features)
+        for i, name in enumerate(old_features):
+            data[name] = data.x[i]
+        return data
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Apply learnable forward pass to the GNN."""
+        # cutter = data.cutter.cumsum(0)[:-1]
+        # Optional embedding of the time and charge time series data.
+        x = data.x
+        time_series = x[:, self._time_series_columns]
+        if self._embedding_dim != 0:
+            time_series = self._emb(time_series * 4096).reshape(
+                (
+                    time_series.shape[0],
+                    self._embedding_dim * time_series.shape[-1],
+                )
+            )
+        # Create the unique DOM + batch splitter from the new_node column.
+        splitter = x[:, -1].argwhere()[1:].flatten().cpu()
+        time_series = time_series.tensor_split(splitter)
+        # Apply the RNN per DOM, irrespective of batch, and return the final state.
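+        # `pack_sequence` packs the variable-length per-DOM series into a
+        # single `PackedSequence` so the GRU can process them in one batched
+        # call; `enforce_sorted=False` means the series do not need to be
+        # pre-sorted by length. The hidden state returned by the GRU is then
+        # used as a fixed-size summary of each DOM's time series.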
+ time_series = torch.nn.utils.rnn.pack_sequence( + time_series, enforce_sorted=False + ) + rnn_out = self._rnn(time_series)[-1][0] + # prepare node level features + charge = data.x[:, self._time_series_columns[0]].tensor_split(splitter) + charge = torch.tensor( + [ + torch.asinh(5 * torch.sum(node_charges) / 5) + for node_charges in charge + ] + ) + batch = data.batch[x[:, -1].bool()] + x = x[x[:, -1].bool()][:, :-1] + x[:, self._time_series_columns[0]] = charge + + # combine the RNN output with the DOM summary features + data.x = torch.hstack([x, rnn_out]) + # correct the batches + data.batch = batch + data = self.clean_up_data_object(data) + # Recompute adjacency + data.edge_index = knn_graph( + x=x[:, self._features_subset], + k=self._nb_neighbors, + batch=batch, + ).to(self.device) + + return data diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index 663664996..53c32deaf 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -273,6 +273,9 @@ def training_step( on_step=False, sync_dist=True, ) + + current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] + self.log("lr", current_lr, prog_bar=True, on_step=True) return loss def validation_step( diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py index e1ef7956c..d05e8223f 100644 --- a/src/graphnet/models/utils.py +++ b/src/graphnet/models/utils.py @@ -1,12 +1,12 @@ """Utility functions for `graphnet.models`.""" -from typing import List, Tuple, Union +from typing import List, Tuple, Any from torch_geometric.nn import knn_graph from torch_geometric.data import Batch import torch from torch import Tensor, LongTensor -from torch_geometric.utils.homophily import homophily +from torch_geometric.utils import homophily def calculate_xyzt_homophily( @@ -59,3 +59,47 @@ def knn_graph_batch(batch: Batch, k: List[int], columns: List[int]) -> Batch: x=data_list[i].x[:, columns], k=k[i] ) return Batch.from_data_list(data_list) + + +def array_to_sequence( + x: Tensor, + batch_idx: LongTensor, + padding_value: Any = 0, + excluding_value: Any = torch.inf, +) -> Tuple[Tensor, Tensor, Tensor]: + """Convert `x` of shape [n,d] into a padded sequence of shape [B, L, D]. + + Where B is the batch size, L is the sequence length and D is the + features for each time step. + + Args: + x: array-like tensor with shape `[n,d]` where `n` is the total number + of pulses in the batch and `d` is the number of node features. + batch_idx: a LongTensor identifying which row in `x` belongs to + which training example. + E.g. `torch_geometric.data.Batch.batch`. + padding_value: The value to use for padding. + excluding_value: This parameter represents a unique value that should + not be present in the input tensor 'x' + Returns: + x: Padded sequence with dimensions [B, L, D]. + mask: A tensor that identifies masked entries in `x`. + E.g. : `masked_entries = x[mask]` + seq_length: A tensor containing the number of pulses in each event. + """ + if torch.any(torch.eq(x, excluding_value)): + raise ValueError( + f"Transformation cannot be made because input tensor " + f"`x` contains at least one element equal to " + f"excluding value {excluding_value}." 
+ ) + + _, seq_length = torch.unique(batch_idx, return_counts=True) + x_list = torch.split(x, seq_length.tolist()) + + x = torch.nn.utils.rnn.pad_sequence( + x_list, batch_first=True, padding_value=excluding_value + ) + mask = torch.ne(x[:, :, 1], excluding_value) + x[~mask] = padding_value + return x, mask, seq_length diff --git a/src/graphnet/pisa/fitting.py b/src/graphnet/pisa/fitting.py index dfcc20a37..5408f9bfc 100644 --- a/src/graphnet/pisa/fitting.py +++ b/src/graphnet/pisa/fitting.py @@ -23,7 +23,7 @@ from pisa.analysis.analysis import Analysis from pisa import ureg -from graphnet.data.sqlite import create_table_and_save_to_sql +from graphnet.data.utilities import create_table_and_save_to_sql mpl.use("pdf") plt.rc("font", family="serif") diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index df7c92e15..fca4a21e0 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -317,3 +317,18 @@ def save_results( model.save_state_dict(path + "/" + tag + "_state_dict.pth") model.save(path + "/" + tag + "_model.pth") Logger().info("Results saved at: \n %s" % path) + + +def save_selection(selection: List[int], file_path: str) -> None: + """Save the list of event numbers to a CSV file. + + Args: + selection: List of event ids. + file_path: File path to save the selection. + """ + assert isinstance( + selection, list + ), "Selection should be a list of integers." + with open(file_path, "w") as f: + f.write(",".join(map(str, selection))) + f.write("\n") diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index a52c91b29..97411bbe5 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -7,7 +7,9 @@ import pandas as pd import sqlite3 -from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql +from graphnet.data.utilities.sqlite_utilities import ( + create_table_and_save_to_sql, +) from graphnet.utilities.logging import Logger diff --git a/src/graphnet/utilities/config/model_config.py b/src/graphnet/utilities/config/model_config.py index 23b4c9b58..18b06def9 100644 --- a/src/graphnet/utilities/config/model_config.py +++ b/src/graphnet/utilities/config/model_config.py @@ -249,6 +249,28 @@ def as_dict(self) -> Dict[str, Dict[str, Any]]: return {self.__class__.__name__: config_dict} + def __repr__(self) -> str: + """Return a string representation of the object.""" + arguments_str = self._format_arguments(self.arguments) + return f"{self.__class__.__name__}(\n{arguments_str}\n)" + + def _format_arguments( + self, arguments: Dict[str, Any], indent: int = 4 + ) -> str: + """Format the arguments dictionary into a string representation.""" + lines = [] + for arg, value in arguments.items(): + if isinstance(value, ModelConfig): + value_str = repr(value) + elif isinstance(value, dict): + value_str = self._format_arguments(value, indent + 4) + else: + value_str = repr(value) + + lines.append(f"{' ' * indent}'{arg}': {value_str},") + + return "{\n" + "\n".join(lines) + "\n" + " " * (indent - 4) + "}" + def save_model_config(init_fn: Callable) -> Callable: """Save the arguments to `__init__` functions as a member `ModelConfig`.""" diff --git a/tests/data/test_dataconverters_and_datasets.py b/tests/data/test_dataconverters_and_datasets.py index 480f11d4d..e1d9e773b 100644 --- a/tests/data/test_dataconverters_and_datasets.py +++ b/tests/data/test_dataconverters_and_datasets.py @@ -11,7 +11,7 @@ from graphnet.constants import TEST_OUTPUT_DIR from 
graphnet.data.constants import FEATURES, TRUTH from graphnet.data.dataconverter import DataConverter -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, @@ -19,7 +19,6 @@ from graphnet.data.parquet import ParquetDataConverter from graphnet.data.dataset import ParquetDataset, SQLiteDataset from graphnet.data.sqlite import SQLiteDataConverter -from graphnet.data.sqlite.sqlite_dataconverter import is_pulse_map from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter from graphnet.utilities.imports import has_icecube_package from graphnet.models.graphs import KNNGraph @@ -52,17 +51,6 @@ def get_file_path(backend: str) -> str: return path -# Unit test(s) -def test_is_pulsemap_check() -> None: - """Test behaviour of `is_pulsemap_check`.""" - assert is_pulse_map("SplitInIcePulses") is True - assert is_pulse_map("SRTInIcePulses") is True - assert is_pulse_map("InIceDSTPulses") is True - assert is_pulse_map("RTTWOfflinePulses") is True - assert is_pulse_map("truth") is False - assert is_pulse_map("retro") is False - - @pytest.mark.order(1) @pytest.mark.parametrize("backend", ["sqlite", "parquet"]) def test_dataconverter( diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py new file mode 100644 index 000000000..9f8a1b745 --- /dev/null +++ b/tests/data/test_datamodule.py @@ -0,0 +1,334 @@ +"""Unit tests for DataModule.""" + +from copy import deepcopy +import os +from typing import List, Any, Dict, Tuple +import pandas as pd +import sqlite3 +import pytest +from torch.utils.data import SequentialSampler + +from graphnet.constants import EXAMPLE_DATA_DIR +from graphnet.data.constants import FEATURES, TRUTH +from graphnet.data.dataset import SQLiteDataset, ParquetDataset +from graphnet.data.datamodule import GraphNeTDataModule +from graphnet.models.detector import IceCubeDeepCore +from graphnet.models.graphs import KNNGraph +from graphnet.models.graphs.nodes import NodesAsPulses +from graphnet.training.utils import save_selection + + +def extract_all_events_ids( + file_path: str, dataset_kwargs: Dict[str, Any] +) -> List[int]: + """Extract all available event ids.""" + if file_path.endswith(".parquet"): + selection = pd.read_parquet(file_path)["event_id"].to_numpy().tolist() + elif file_path.endswith(".db"): + with sqlite3.connect(file_path) as conn: + query = f'SELECT event_no FROM {dataset_kwargs["truth_table"]}' + selection = ( + pd.read_sql(query, conn)["event_no"].to_numpy().tolist() + ) + else: + raise AssertionError( + f"File extension not accepted: {file_path.split('.')[-1]}" + ) + return selection + + +@pytest.fixture +def dataset_ref(request: pytest.FixtureRequest) -> pytest.FixtureRequest: + """Return the dataset reference.""" + return request.param + + +@pytest.fixture +def dataset_setup(dataset_ref: pytest.FixtureRequest) -> tuple: + """Set up the dataset for testing. + + Args: + dataset_ref: The dataset reference. + + Returns: + A tuple with the dataset reference, dataset kwargs, and dataloader kwargs. 
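+
+    Note:
+        The paths point at the public Prometheus example data referenced by
+        `EXAMPLE_DATA_DIR`, which keeps the resulting datasets small enough
+        for unit tests.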
+ """ + # Grab public dataset paths + data_path = ( + f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db" + if dataset_ref is SQLiteDataset + else f"{EXAMPLE_DATA_DIR}/parquet/prometheus/prometheus-events.parquet" + ) + + # Setup basic inputs; can be altered by individual tests + graph_definition = KNNGraph( + detector=IceCubeDeepCore(), + node_definition=NodesAsPulses(), + nb_nearest_neighbours=8, + input_feature_names=FEATURES.DEEPCORE, + ) + + dataset_kwargs = { + "truth_table": "mc_truth", + "pulsemaps": "total", + "truth": TRUTH.PROMETHEUS, + "features": FEATURES.PROMETHEUS, + "path": data_path, + "graph_definition": graph_definition, + } + + dataloader_kwargs = {"batch_size": 2, "num_workers": 1} + + return dataset_ref, dataset_kwargs, dataloader_kwargs + + +@pytest.fixture +def selection() -> List[int]: + """Return a selection.""" + return [1, 2, 3, 4, 5] + + +@pytest.fixture +def file_path(tmpdir: str) -> str: + """Return a file path.""" + return os.path.join(tmpdir, "selection.csv") + + +def test_save_selection(selection: List[int], file_path: str) -> None: + """Test `save_selection` function.""" + save_selection(selection, file_path) + + assert os.path.exists(file_path) + + with open(file_path, "r") as f: + content = f.read() + assert content.strip() == "1,2,3,4,5" + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_single_dataset_without_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Verify GraphNeTDataModule behavior when no test selection is provided. + + Args: + dataset_setup: Tuple with dataset reference, dataset arguments, and dataloader arguments. + + Raises: + Exception: If the test dataloader is accessed without providing a test selection. + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + + # Only training_dataloader args + # Default values should be assigned to validation dataloader + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + ) + + train_dataloader = dm.train_dataloader + val_dataloader = dm.val_dataloader + + with pytest.raises(Exception): + # should fail because we provided no test selection + test_dataloader = dm.test_dataloader # noqa + # validation loader should have shuffle = False by default + assert isinstance(val_dataloader.sampler, SequentialSampler) + # Should have identical batch_size + assert val_dataloader.batch_size != train_dataloader.batch_size + # Training dataloader should contain more batches + assert len(train_dataloader) > len(val_dataloader) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_single_dataset_with_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test that selection functionality of DataModule behaves as expected. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset arguments, and dataloader arguments. 
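+
+    Note:
+        `selection` is split internally into train and validation parts,
+        while `test_selection` is reserved exclusively for the test
+        dataloader, as the assertions below verify.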
+ + Returns: + None + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + # extract all events + file_path = dataset_kwargs["path"] + selection = extract_all_events_ids( + file_path=file_path, dataset_kwargs=dataset_kwargs + ) + + test_selection = selection[0:10] + train_val_selection = selection[10:] + + # Only training_dataloader args + # Default values should be assigned to validation dataloader + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=train_val_selection, + test_selection=test_selection, + ) + + train_dataloader = dm.train_dataloader + val_dataloader = dm.val_dataloader + test_dataloader = dm.test_dataloader + + # Check that the training and validation dataloader contains + # the same number of events as was given in the selection. + assert len(train_dataloader.dataset) + len(val_dataloader.dataset) == len(train_val_selection) # type: ignore + # Check that the number of events in the test dataset is equal to the + # number of events given in the selection. + assert len(test_dataloader.dataset) == len(test_selection) # type: ignore + # Training dataloader should have more batches + assert len(train_dataloader) > len(val_dataloader) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_dataloader_args( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test that arguments to dataloaders are propagated correctly. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. + + Returns: + None + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + val_dataloader_kwargs = deepcopy(dataloader_kwargs) + test_dataloader_kwargs = deepcopy(dataloader_kwargs) + + # Setting batch sizes to different values + val_dataloader_kwargs["batch_size"] = 1 + test_dataloader_kwargs["batch_size"] = 2 + dataloader_kwargs["batch_size"] = 3 + + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + validation_dataloader_kwargs=val_dataloader_kwargs, + test_dataloader_kwargs=test_dataloader_kwargs, + ) + + # Check that the resulting dataloaders have the right batch sizes + assert dm.train_dataloader.batch_size == dataloader_kwargs["batch_size"] + assert dm.val_dataloader.batch_size == val_dataloader_kwargs["batch_size"] + assert ( + dm.test_dataloader.batch_size == test_dataloader_kwargs["batch_size"] + ) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_ensemble_dataset_without_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test ensemble dataset functionality without selections. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. 
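+
+    Note:
+        An ensemble dataset is created by passing a list of paths in
+        `dataset_args["path"]`; the DataModule then builds one dataset per
+        path, so the resulting dataloaders hold more batches than in the
+        single-dataset case.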
+ + Returns: + None + """ + # Make dataloaders from single dataset + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + dm_single = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=deepcopy(dataset_kwargs), + train_dataloader_kwargs=dataloader_kwargs, + ) + + # Copy dataset path twice; mimic ensemble dataset behavior + ensemble_dataset_kwargs = deepcopy(dataset_kwargs) + dataset_path = ensemble_dataset_kwargs["path"] + ensemble_dataset_kwargs["path"] = [dataset_path, dataset_path] + + # Create dataloaders from multiple datasets + dm_ensemble = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + ) + + # Test that the ensemble dataloaders contain more batches + assert len(dm_single.train_dataloader) < len(dm_ensemble.train_dataloader) + assert len(dm_single.val_dataloader) < len(dm_ensemble.val_dataloader) + + +@pytest.mark.parametrize("dataset_ref", [SQLiteDataset, ParquetDataset]) +def test_ensemble_dataset_with_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test ensemble dataset functionality with selections. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. + + Returns: + None + """ + # extract all events + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + file_path = dataset_kwargs["path"] + selection = extract_all_events_ids( + file_path=file_path, dataset_kwargs=dataset_kwargs + ) + + # Copy dataset path twice; mimic ensemble dataset behavior + ensemble_dataset_kwargs = deepcopy(dataset_kwargs) + dataset_path = ensemble_dataset_kwargs["path"] + ensemble_dataset_kwargs["path"] = [dataset_path, dataset_path] + + # pass two datasets but only one selection; should fail: + with pytest.raises(Exception): + _ = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=selection, + ) + + # Pass two datasets and two selections; should work: + selection_1 = selection[0:20] + selection_2 = selection[0:10] + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=[selection_1, selection_2], + ) + n_events_in_dataloaders = len(dm.train_dataloader.dataset) + len(dm.val_dataloader.dataset) # type: ignore + + # Check that the number of events in train/val match + assert n_events_in_dataloaders == len(selection_1) + len(selection_2) + + # Pass two datasets, two selections and two test selections; should work + dm2 = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=[selection, selection], + test_selection=[selection_1, selection_2], + ) + + # Check that the number of events in test dataloaders are correct. 
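+    # Each test selection is applied to its own copy of the dataset, so the
+    # combined test set contains the events from both selections.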
+ n_events_in_test_dataloaders = len(dm2.test_dataloader.dataset) # type: ignore + assert n_events_in_test_dataloaders == len(selection_1) + len(selection_2) diff --git a/tests/data/test_i3extractor.py b/tests/data/test_i3extractor.py index 3fa19f078..ce40626c0 100644 --- a/tests/data/test_i3extractor.py +++ b/tests/data/test_i3extractor.py @@ -1,6 +1,6 @@ -"""Unit tests for I3Extractor class.""" +"""Unit tests for I3Extractor.""" -from graphnet.data.extractors import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCube86, I3TruthExtractor, I3RetroExtractor, diff --git a/tests/data/test_i3genericextractor.py b/tests/data/test_i3genericextractor.py deleted file mode 100644 index 314fa5f44..000000000 --- a/tests/data/test_i3genericextractor.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Unit tests for I3GenericExtractor class.""" - -import os - -import numpy as np - -import graphnet.constants -from graphnet.data.extractors import ( - I3FeatureExtractorIceCube86, - I3TruthExtractor, - I3GenericExtractor, -) -from graphnet.utilities.imports import has_icecube_package - -if has_icecube_package(): - from icecube import dataio # pyright: reportMissingImports=false - -# Global variable(s) -TEST_DATA_DIR = os.path.join( - graphnet.constants.TEST_DATA_DIR, "i3", "oscNext_genie_level7_v02" -) -FILE_NAME = "oscNext_genie_level7_v02_first_5_frames" -GCD_FILE = ( - "GeoCalibDetectorStatus_AVG_55697-57531_PASS2_SPE_withScaledNoise.i3.gz" -) - - -# Unit test(s) -def test_i3genericextractor(test_data_dir: str = TEST_DATA_DIR) -> None: - """Test the implementation of `I3GenericExtractor`.""" - # Constants(s) - mc_tree = "I3MCTree" - pulse_series = "SRTInIcePulses" - - # Constructor I3Extractor instance(s) - generic_extractor = I3GenericExtractor(keys=[mc_tree, pulse_series]) - truth_extractor = I3TruthExtractor() - feature_extractor = I3FeatureExtractorIceCube86(pulse_series) - - i3_file = os.path.join(test_data_dir, FILE_NAME) + ".i3.gz" - gcd_file = os.path.join(test_data_dir, GCD_FILE) - - generic_extractor.set_files(i3_file, gcd_file) - truth_extractor.set_files(i3_file, gcd_file) - feature_extractor.set_files(i3_file, gcd_file) - - i3_file_io = dataio.I3File(i3_file, "r") - ix_test = 5 - while i3_file_io.more(): - try: - frame = i3_file_io.pop_physics() - except: # noqa: E722 - continue - - generic_data = generic_extractor(frame) - truth_data = truth_extractor(frame) - feature_data = feature_extractor(frame) - - if ix_test == 5: - print(list(generic_data[pulse_series].keys())) - print(list(truth_data.keys())) - print(list(feature_data.keys())) - - # Truth vs. generic - key_pairs = [ - ("energy", "energy"), - ("zenith", "dir__zenith"), - ("azimuth", "dir__azimuth"), - ("pid", "pdg_encoding"), - ] - - for truth_key, generic_key in key_pairs: - assert ( - truth_data[truth_key] - == generic_data[f"{mc_tree}__primaries"][generic_key][0] - ) - - # Reco vs. 
generic - key_pairs = [ - ("charge", "charge"), - ("dom_time", "time"), - ("dom_x", "position__x"), - ("dom_y", "position__y"), - ("dom_z", "position__z"), - ("width", "width"), - ("pmt_area", "area"), - ("rde", "relative_dom_eff"), - ] - - for reco_key, generic_key in key_pairs: - assert np.allclose( - feature_data[reco_key], generic_data[pulse_series][generic_key] - ) - - ix_test -= 1 - if ix_test == 0: - break diff --git a/tests/deployment/queso_test.py b/tests/deployment/queso_test.py index d1258ed89..cf761bf52 100644 --- a/tests/deployment/queso_test.py +++ b/tests/deployment/queso_test.py @@ -8,7 +8,7 @@ import pytest from graphnet.data.constants import FEATURES -from graphnet.data.extractors.i3featureextractor import ( +from graphnet.data.extractors.icecube import ( I3FeatureExtractorIceCubeUpgrade, ) from graphnet.constants import ( @@ -23,7 +23,6 @@ from graphnet.deployment.i3modules import ( I3InferenceModule, - GraphNeTI3Module, I3PulseCleanerModule, ) @@ -34,7 +33,7 @@ def apply_to_files( i3_files: List[str], gcd_file: str, output_folder: str, - modules: Sequence["GraphNeTI3Module"], + modules: Sequence["I3InferenceModule"], ) -> None: """Will start an IceTray read/write chain with graphnet modules. diff --git a/tests/utilities/test_model_config.py b/tests/utilities/test_model_config.py index 1df87b1c8..f9ae49dce 100644 --- a/tests/utilities/test_model_config.py +++ b/tests/utilities/test_model_config.py @@ -122,8 +122,9 @@ def test_complete_model_config(path: str = "/tmp/complete_model.yml") -> None: "transform_prediction_and_target" ](x_) ) - - assert repr(constructed_model) == repr(model) + model.set_verbose_print_recursively(False) + constructed_model.set_verbose_print_recursively(False) + assert repr(model) == repr(constructed_model) @pytest.mark.run(after="test_complete_model_config")