From 64667c2fa13892ed897f54cccf793aa79c952802 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20S=C3=B8gaard?= Date: Thu, 14 Sep 2023 03:54:02 +0000 Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20@=20Aske-Ros?= =?UTF-8?q?ted/graphnet@058ce59bcd2fedb22b27d3d545be82b0d5aa47bd=20?= =?UTF-8?q?=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- _modules/graphnet/data/constants.html | 2 +- _modules/graphnet/data/dataconverter.html | 2 +- _modules/graphnet/data/dataloader.html | 457 ++++++ _modules/graphnet/data/dataset/dataset.html | 1084 ++++++++++++++ .../data/dataset/parquet/parquet_dataset.html | 500 +++++++ .../data/dataset/sqlite/sqlite_dataset.html | 515 +++++++ .../sqlite/sqlite_dataset_perturbed.html | 515 +++++++ .../graphnet/data/extractors/i3extractor.html | 2 +- .../data/extractors/i3featureextractor.html | 2 +- .../data/extractors/i3genericextractor.html | 2 +- .../extractors/i3hybridrecoextractor.html | 2 +- .../extractors/i3ntmuonlabelsextractor.html | 2 +- .../data/extractors/i3particleextractor.html | 2 +- .../data/extractors/i3pisaextractor.html | 2 +- .../data/extractors/i3quesoextractor.html | 2 +- .../data/extractors/i3retroextractor.html | 2 +- .../data/extractors/i3splinempeextractor.html | 2 +- .../data/extractors/i3truthextractor.html | 2 +- .../data/extractors/i3tumextractor.html | 2 +- .../extractors/utilities/collections.html | 2 +- .../data/extractors/utilities/frames.html | 2 +- .../data/extractors/utilities/types.html | 2 +- .../data/parquet/parquet_dataconverter.html | 2 +- _modules/graphnet/data/pipeline.html | 593 ++++++++ .../data/sqlite/sqlite_dataconverter.html | 2 +- .../data/sqlite/sqlite_utilities.html | 2 +- .../data/utilities/parquet_to_sqlite.html | 2 +- _modules/graphnet/data/utilities/random.html | 2 +- .../utilities/string_selection_resolver.html | 2 +- .../deployment/i3modules/graphnet_module.html | 817 +++++++++++ _modules/graphnet/models/coarsening.html | 711 +++++++++ .../graphnet/models/components/layers.html | 579 ++++++++ _modules/graphnet/models/components/pool.html | 656 +++++++++ .../graphnet/models/detector/detector.html | 421 ++++++ .../graphnet/models/detector/icecube.html | 528 +++++++ .../graphnet/models/detector/prometheus.html | 395 +++++ _modules/graphnet/models/gnn/convnet.html | 486 +++++++ _modules/graphnet/models/gnn/dynedge.html | 693 +++++++++ .../graphnet/models/gnn/dynedge_jinst.html | 521 +++++++ .../models/gnn/dynedge_kaggle_tito.html | 618 ++++++++ _modules/graphnet/models/gnn/gnn.html | 403 +++++ .../graphnet/models/graphs/edges/edges.html | 563 +++++++ .../models/graphs/graph_definition.html | 634 ++++++++ _modules/graphnet/models/graphs/graphs.html | 410 ++++++ .../graphnet/models/graphs/nodes/nodes.html | 444 ++++++ _modules/graphnet/models/model.html | 726 ++++++++++ _modules/graphnet/models/standard_model.html | 604 ++++++++ .../graphnet/models/task/classification.html | 411 ++++++ .../graphnet/models/task/reconstruction.html | 609 ++++++++ _modules/graphnet/models/task/task.html | 688 +++++++++ _modules/graphnet/models/utils.html | 430 ++++++ _modules/graphnet/pisa/fitting.html | 2 +- _modules/graphnet/pisa/plotting.html | 2 +- _modules/graphnet/training/callbacks.html | 544 +++++++ _modules/graphnet/training/labels.html | 436 ++++++ .../graphnet/training/loss_functions.html | 859 +++++++++++ _modules/graphnet/training/utils.html | 656 +++++++++ .../graphnet/training/weight_fitting.html | 2 +- _modules/graphnet/utilities/argparse.html | 2 +- .../utilities/config/base_config.html | 449 ++++++ .../utilities/config/configurable.html | 408 ++++++ .../utilities/config/dataset_config.html | 585 ++++++++ .../utilities/config/model_config.html | 654 +++++++++ .../graphnet/utilities/config/parsing.html | 475 ++++++ .../utilities/config/training_config.html | 378 +++++ _modules/graphnet/utilities/filesys.html | 2 +- _modules/graphnet/utilities/imports.html | 2 +- _modules/graphnet/utilities/logging.html | 2 +- _modules/graphnet/utilities/maths.html | 371 +++++ _modules/index.html | 41 +- about.html | 2 +- api/graphnet.constants.html | 2 +- api/graphnet.data.constants.html | 2 +- api/graphnet.data.dataconverter.html | 2 +- api/graphnet.data.dataloader.html | 129 +- api/graphnet.data.dataset.dataset.html | 330 ++++- api/graphnet.data.dataset.html | 16 +- api/graphnet.data.dataset.parquet.html | 12 +- ....data.dataset.parquet.parquet_dataset.html | 116 +- api/graphnet.data.dataset.sqlite.html | 17 +- ...et.data.dataset.sqlite.sqlite_dataset.html | 116 +- ...taset.sqlite.sqlite_dataset_perturbed.html | 84 +- api/graphnet.data.extractors.html | 2 +- api/graphnet.data.extractors.i3extractor.html | 2 +- ...et.data.extractors.i3featureextractor.html | 2 +- ...et.data.extractors.i3genericextractor.html | 2 +- ...data.extractors.i3hybridrecoextractor.html | 2 +- ...ta.extractors.i3ntmuonlabelsextractor.html | 2 +- ...t.data.extractors.i3particleextractor.html | 2 +- ...phnet.data.extractors.i3pisaextractor.html | 2 +- ...hnet.data.extractors.i3quesoextractor.html | 2 +- ...hnet.data.extractors.i3retroextractor.html | 2 +- ....data.extractors.i3splinempeextractor.html | 2 +- ...hnet.data.extractors.i3truthextractor.html | 2 +- ...aphnet.data.extractors.i3tumextractor.html | 2 +- ...data.extractors.utilities.collections.html | 2 +- ...hnet.data.extractors.utilities.frames.html | 2 +- api/graphnet.data.extractors.utilities.html | 2 +- ...phnet.data.extractors.utilities.types.html | 2 +- api/graphnet.data.html | 14 +- api/graphnet.data.parquet.html | 2 +- ...et.data.parquet.parquet_dataconverter.html | 2 +- api/graphnet.data.pipeline.html | 55 +- api/graphnet.data.sqlite.html | 2 +- ...hnet.data.sqlite.sqlite_dataconverter.html | 2 +- ...graphnet.data.sqlite.sqlite_utilities.html | 2 +- api/graphnet.data.utilities.html | 2 +- ...hnet.data.utilities.parquet_to_sqlite.html | 2 +- api/graphnet.data.utilities.random.html | 2 +- ...a.utilities.string_selection_resolver.html | 4 +- api/graphnet.deployment.html | 2 +- ...raphnet.deployment.i3modules.deployer.html | 2 +- ....deployment.i3modules.graphnet_module.html | 135 +- api/graphnet.deployment.i3modules.html | 9 +- api/graphnet.html | 2 +- api/graphnet.models.coarsening.html | 226 ++- api/graphnet.models.components.html | 28 +- api/graphnet.models.components.layers.html | 253 +++- api/graphnet.models.components.pool.html | 339 ++++- api/graphnet.models.detector.detector.html | 89 +- api/graphnet.models.detector.html | 25 +- api/graphnet.models.detector.icecube.html | 197 ++- api/graphnet.models.detector.prometheus.html | 62 +- api/graphnet.models.gnn.convnet.html | 76 +- api/graphnet.models.gnn.dynedge.html | 98 +- api/graphnet.models.gnn.dynedge_jinst.html | 73 +- ...aphnet.models.gnn.dynedge_kaggle_tito.html | 83 +- api/graphnet.models.gnn.gnn.html | 103 +- api/graphnet.models.gnn.html | 32 +- api/graphnet.models.graphs.edges.edges.html | 169 ++- api/graphnet.models.graphs.edges.html | 18 +- ...aphnet.models.graphs.graph_definition.html | 93 +- api/graphnet.models.graphs.graphs.html | 50 +- api/graphnet.models.graphs.html | 20 +- api/graphnet.models.graphs.nodes.html | 16 +- api/graphnet.models.graphs.nodes.nodes.html | 132 +- api/graphnet.models.html | 39 +- api/graphnet.models.model.html | 305 +++- api/graphnet.models.standard_model.html | 348 ++++- api/graphnet.models.task.classification.html | 262 +++- api/graphnet.models.task.html | 36 +- api/graphnet.models.task.reconstruction.html | 1290 ++++++++++++++++- api/graphnet.models.task.task.html | 302 +++- api/graphnet.models.utils.html | 110 +- api/graphnet.pisa.fitting.html | 2 +- api/graphnet.pisa.html | 2 +- api/graphnet.pisa.plotting.html | 2 +- api/graphnet.training.callbacks.html | 277 +++- api/graphnet.training.html | 38 +- api/graphnet.training.labels.html | 93 +- api/graphnet.training.loss_functions.html | 488 ++++++- api/graphnet.training.utils.html | 191 ++- api/graphnet.training.weight_fitting.html | 2 +- api/graphnet.utilities.argparse.html | 2 +- ...graphnet.utilities.config.base_config.html | 177 ++- ...raphnet.utilities.config.configurable.html | 105 +- ...phnet.utilities.config.dataset_config.html | 438 +++++- api/graphnet.utilities.config.html | 45 +- ...raphnet.utilities.config.model_config.html | 177 ++- api/graphnet.utilities.config.parsing.html | 165 ++- ...hnet.utilities.config.training_config.html | 147 +- api/graphnet.utilities.decorators.html | 2 +- api/graphnet.utilities.filesys.html | 2 +- api/graphnet.utilities.html | 7 +- api/graphnet.utilities.imports.html | 2 +- api/graphnet.utilities.logging.html | 2 +- api/graphnet.utilities.maths.html | 41 +- api/modules.html | 2 +- contribute.html | 2 +- genindex.html | 1174 ++++++++++++++- index.html | 2 +- install.html | 2 +- objects.inv | Bin 3520 -> 6210 bytes py-modindex.html | 257 +++- search.html | 2 +- searchindex.js | 2 +- sitemap.xml | 2 +- 177 files changed, 31425 insertions(+), 329 deletions(-) create mode 100644 _modules/graphnet/data/dataloader.html create mode 100644 _modules/graphnet/data/dataset/dataset.html create mode 100644 _modules/graphnet/data/dataset/parquet/parquet_dataset.html create mode 100644 _modules/graphnet/data/dataset/sqlite/sqlite_dataset.html create mode 100644 _modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html create mode 100644 _modules/graphnet/data/pipeline.html create mode 100644 _modules/graphnet/deployment/i3modules/graphnet_module.html create mode 100644 _modules/graphnet/models/coarsening.html create mode 100644 _modules/graphnet/models/components/layers.html create mode 100644 _modules/graphnet/models/components/pool.html create mode 100644 _modules/graphnet/models/detector/detector.html create mode 100644 _modules/graphnet/models/detector/icecube.html create mode 100644 _modules/graphnet/models/detector/prometheus.html create mode 100644 _modules/graphnet/models/gnn/convnet.html create mode 100644 _modules/graphnet/models/gnn/dynedge.html create mode 100644 _modules/graphnet/models/gnn/dynedge_jinst.html create mode 100644 _modules/graphnet/models/gnn/dynedge_kaggle_tito.html create mode 100644 _modules/graphnet/models/gnn/gnn.html create mode 100644 _modules/graphnet/models/graphs/edges/edges.html create mode 100644 _modules/graphnet/models/graphs/graph_definition.html create mode 100644 _modules/graphnet/models/graphs/graphs.html create mode 100644 _modules/graphnet/models/graphs/nodes/nodes.html create mode 100644 _modules/graphnet/models/model.html create mode 100644 _modules/graphnet/models/standard_model.html create mode 100644 _modules/graphnet/models/task/classification.html create mode 100644 _modules/graphnet/models/task/reconstruction.html create mode 100644 _modules/graphnet/models/task/task.html create mode 100644 _modules/graphnet/models/utils.html create mode 100644 _modules/graphnet/training/callbacks.html create mode 100644 _modules/graphnet/training/labels.html create mode 100644 _modules/graphnet/training/loss_functions.html create mode 100644 _modules/graphnet/training/utils.html create mode 100644 _modules/graphnet/utilities/config/base_config.html create mode 100644 _modules/graphnet/utilities/config/configurable.html create mode 100644 _modules/graphnet/utilities/config/dataset_config.html create mode 100644 _modules/graphnet/utilities/config/model_config.html create mode 100644 _modules/graphnet/utilities/config/parsing.html create mode 100644 _modules/graphnet/utilities/config/training_config.html create mode 100644 _modules/graphnet/utilities/maths.html diff --git a/_modules/graphnet/data/constants.html b/_modules/graphnet/data/constants.html index cd439b37c..b6526af39 100644 --- a/_modules/graphnet/data/constants.html +++ b/_modules/graphnet/data/constants.html @@ -436,7 +436,7 @@

Source code for graphnet.dat Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/dataconverter.html b/_modules/graphnet/data/dataconverter.html index aa78de84b..9be52a7d9 100644 --- a/_modules/graphnet/data/dataconverter.html +++ b/_modules/graphnet/data/dataconverter.html @@ -938,7 +938,7 @@

Source code for graphnet Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/dataloader.html b/_modules/graphnet/data/dataloader.html new file mode 100644 index 000000000..4de320872 --- /dev/null +++ b/_modules/graphnet/data/dataloader.html @@ -0,0 +1,457 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.data.dataloader — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.data.dataloader

+"""Base `Dataloader` class(es) used in `graphnet`."""
+
+from typing import Any, Callable, Dict, List, Union
+
+import torch.utils.data
+from torch_geometric.data import Batch, Data
+
+from graphnet.data.dataset import Dataset
+from graphnet.utilities.config import DatasetConfig
+
+
+
+[docs] +def collate_fn(graphs: List[Data]) -> Batch: + """Remove graphs with less than two DOM hits. + + Should not occur in "production. + """ + graphs = [g for g in graphs if g.n_pulses > 1] + return Batch.from_data_list(graphs)
+ + + +
+[docs] +def do_shuffle(selection_name: str) -> bool: + """Check whether to shuffle selection with name `selection_name`.""" + return "train" in selection_name.lower()
+ + + +
+[docs] +class DataLoader(torch.utils.data.DataLoader): + """Class for loading data from a `Dataset`.""" + + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + shuffle: bool = False, + num_workers: int = 10, + persistent_workers: bool = True, + collate_fn: Callable = collate_fn, + prefetch_factor: int = 2, + **kwargs: Any, + ) -> None: + """Construct `DataLoader`.""" + # Base class constructor + super().__init__( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + **kwargs, + ) + +
+[docs] + @classmethod + def from_dataset_config( + cls, + config: DatasetConfig, + **kwargs: Any, + ) -> Union["DataLoader", Dict[str, "DataLoader"]]: + """Construct `DataLoader`s based on selections in `DatasetConfig`.""" + if isinstance(config.selection, dict): + assert "shuffle" not in kwargs, ( + "When passing a `DatasetConfig` with multiple selections, " + "`shuffle` is automatically inferred from the selection name, " + "and thus should not specified as an argument." + ) + datasets = Dataset.from_config(config) + assert isinstance(datasets, dict) + data_loaders: Dict[str, DataLoader] = {} + for name, dataset in datasets.items(): + data_loaders[name] = cls( + dataset, + shuffle=do_shuffle(name), + **kwargs, + ) + + return data_loaders + + else: + assert "shuffle" in kwargs, ( + "When passing a `DatasetConfig` with a single selections, you " + "need to specify `shuffle` as an argument." + ) + dataset = Dataset.from_config(config) + assert isinstance(dataset, Dataset) + return cls(dataset, **kwargs)
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/dataset.html b/_modules/graphnet/data/dataset/dataset.html new file mode 100644 index 000000000..f2ee86f4d --- /dev/null +++ b/_modules/graphnet/data/dataset/dataset.html @@ -0,0 +1,1084 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.data.dataset.dataset — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.data.dataset.dataset

+"""Base :py:class:`Dataset` class(es) used in GraphNeT."""
+
+from copy import deepcopy
+from abc import ABC, abstractmethod
+from typing import (
+    cast,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    Iterable,
+    Type,
+)
+
+import numpy as np
+import torch
+from torch_geometric.data import Data
+
+from graphnet.constants import GRAPHNET_ROOT_DIR
+from graphnet.data.utilities.string_selection_resolver import (
+    StringSelectionResolver,
+)
+from graphnet.training.labels import Label
+from graphnet.utilities.config import (
+    Configurable,
+    DatasetConfig,
+    save_dataset_config,
+)
+from graphnet.utilities.config.parsing import traverse_and_apply
+from graphnet.utilities.logging import Logger
+from graphnet.models.graphs import GraphDefinition
+
+from graphnet.utilities.config.parsing import (
+    get_all_grapnet_classes,
+)
+
+
+
+[docs] +class ColumnMissingException(Exception): + """Exception to indicate a missing column in a dataset."""
+ + + +
+[docs] +def load_module(class_name: str) -> Type: + """Load graphnet module from string name. + + Args: + class_name: name of class + + Returns: + graphnet module. + """ + # Get a lookup for all classes in `graphnet` + import graphnet.data + import graphnet.models + import graphnet.training + + namespace_classes = get_all_grapnet_classes( + graphnet.data, graphnet.models, graphnet.training + ) + return namespace_classes[class_name]
+ + + +
+[docs] +def parse_graph_definition(cfg: dict) -> GraphDefinition: + """Construct GraphDefinition from DatasetConfig.""" + assert cfg["graph_definition"] is not None + + args = cfg["graph_definition"]["arguments"] + classes = {} + for arg in args.keys(): + if isinstance(args[arg], dict): + if "class_name" in args[arg].keys(): + classes[arg] = load_module(args[arg]["class_name"])( + **args[arg]["arguments"] + ) + if arg == "dtype": + args[arg] = eval(args[arg]) # converts string to class + + new_cfg = deepcopy(args) + new_cfg.update(classes) + graph_definition = load_module(cfg["graph_definition"]["class_name"])( + **new_cfg + ) + return graph_definition
+ + + +
+[docs] +class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC): + """Base Dataset class for reading from any intermediate file format.""" + + # Class method(s) +
+[docs] + @classmethod + def from_config( # type: ignore[override] + cls, + source: Union[DatasetConfig, str], + ) -> Union[ + "Dataset", + "EnsembleDataset", + Dict[str, "Dataset"], + Dict[str, "EnsembleDataset"], + ]: + """Construct `Dataset` instance from `source` configuration.""" + if isinstance(source, str): + source = DatasetConfig.load(source) + + assert isinstance(source, DatasetConfig), ( + f"Argument `source` of type ({type(source)}) is not a " + "`DatasetConfig`" + ) + + assert ( + "graph_definition" in source.dict().keys() + ), "`DatasetConfig` incompatible with current GraphNeT version." + + # Parse set of `selection``. + if isinstance(source.selection, dict): + return cls._construct_datasets_from_dict(source) + elif ( + isinstance(source.selection, list) + and len(source.selection) + and isinstance(source.selection[0], str) + ): + return cls._construct_dataset_from_list_of_strings(source) + + cfg = source.dict() + if cfg["graph_definition"] is not None: + cfg["graph_definition"] = parse_graph_definition(cfg) + return source._dataset_class(**cfg)
+ + +
+[docs] + @classmethod + def concatenate( + cls, + datasets: List["Dataset"], + ) -> "EnsembleDataset": + """Concatenate multiple `Dataset`s into one instance.""" + return EnsembleDataset(datasets)
+ + + @classmethod + def _construct_datasets_from_dict( + cls, config: DatasetConfig + ) -> Dict[str, "Dataset"]: + """Construct `Dataset` for each entry in dict `self.selection`.""" + assert isinstance(config.selection, dict) + datasets: Dict[str, "Dataset"] = {} + selections: Dict[str, Union[str, List]] = deepcopy(config.selection) + for key, selection in selections.items(): + config.selection = selection + dataset = Dataset.from_config(config) + assert isinstance(dataset, (Dataset, EnsembleDataset)) + datasets[key] = dataset + + # Reset `selections`. + config.selection = selections + + return datasets + + @classmethod + def _construct_dataset_from_list_of_strings( + cls, config: DatasetConfig + ) -> "Dataset": + """Construct `Dataset` for each entry in list `self.selection`.""" + assert isinstance(config.selection, list) + datasets: List["Dataset"] = [] + selections: List[str] = deepcopy(cast(List[str], config.selection)) + for selection in selections: + config.selection = selection + dataset = Dataset.from_config(config) + assert isinstance(dataset, Dataset) + datasets.append(dataset) + + # Reset `selections`. + config.selection = selections + + return cls.concatenate(datasets) + + @classmethod + def _resolve_graphnet_paths( + cls, path: Union[str, List[str]] + ) -> Union[str, List[str]]: + if isinstance(path, list): + return [cast(str, cls._resolve_graphnet_paths(p)) for p in path] + + assert isinstance(path, str) + return ( + path.replace("$graphnet", GRAPHNET_ROOT_DIR) + .replace("$GRAPHNET", GRAPHNET_ROOT_DIR) + .replace("${graphnet}", GRAPHNET_ROOT_DIR) + .replace("${GRAPHNET}", GRAPHNET_ROOT_DIR) + ) + + @save_dataset_config + def __init__( + self, + path: Union[str, List[str]], + graph_definition: GraphDefinition, + pulsemaps: Union[str, List[str]], + features: List[str], + truth: List[str], + *, + node_truth: Optional[List[str]] = None, + index_column: str = "event_no", + truth_table: str = "truth", + node_truth_table: Optional[str] = None, + string_selection: Optional[List[int]] = None, + selection: Optional[Union[str, List[int], List[List[int]]]] = None, + dtype: torch.dtype = torch.float32, + loss_weight_table: Optional[str] = None, + loss_weight_column: Optional[str] = None, + loss_weight_default_value: Optional[float] = None, + seed: Optional[int] = None, + ): + """Construct Dataset. + + Args: + path: Path to the file(s) from which this `Dataset` should read. + pulsemaps: Name(s) of the pulse map series that should be used to + construct the nodes on the individual graph objects, and their + features. Multiple pulse series maps can be used, e.g., when + different DOM types are stored in different maps. + features: List of columns in the input files that should be used as + node features on the graph objects. + truth: List of event-level columns in the input files that should + be used added as attributes on the graph objects. + node_truth: List of node-level columns in the input files that + should be used added as attributes on the graph objects. + index_column: Name of the column in the input files that contains + unique indicies to identify and map events across tables. + truth_table: Name of the table containing event-level truth + information. + node_truth_table: Name of the table containing node-level truth + information. + string_selection: Subset of strings for which data should be read + and used to construct graph objects. Defaults to None, meaning + all strings for which data exists are used. + selection: The events that should be read. This can be given either + as list of indicies (in `index_column`); or a string-based + selection used to query the `Dataset` for events passing the + selection. Defaults to None, meaning that all events in the + input files are read. + dtype: Type of the feature tensor on the graph objects returned. + loss_weight_table: Name of the table containing per-event loss + weights. + loss_weight_column: Name of the column in `loss_weight_table` + containing per-event loss weights. This is also the name of the + corresponding attribute assigned to the graph object. + loss_weight_default_value: Default per-event loss weight. + NOTE: This default value is only applied when + `loss_weight_table` and `loss_weight_column` are specified, and + in this case to events with no value in the corresponding + table/column. That is, if no per-event loss weight table/column + is provided, this value is ignored. Defaults to None. + seed: Random number generator seed, used for selecting a random + subset of events when resolving a string-based selection (e.g., + `"10000 random events ~ event_no % 5 > 0"` or `"20% random + events ~ event_no % 5 > 0"`). + graph_definition: Method that defines the graph representation. + """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Check(s) + if isinstance(pulsemaps, str): + pulsemaps = [pulsemaps] + + assert isinstance(features, (list, tuple)) + assert isinstance(truth, (list, tuple)) + + # Resolve reference to `$GRAPHNET` in path(s) + path = self._resolve_graphnet_paths(path) + + # Member variable(s) + self._path = path + self._selection = None + self._pulsemaps = pulsemaps + self._features = [index_column] + features + self._truth = [index_column] + truth + self._index_column = index_column + self._truth_table = truth_table + self._loss_weight_default_value = loss_weight_default_value + self._graph_definition = graph_definition + + if node_truth is not None: + assert isinstance(node_truth_table, str) + if isinstance(node_truth, str): + node_truth = [node_truth] + + self._node_truth = node_truth + self._node_truth_table = node_truth_table + + if string_selection is not None: + self.warning( + ( + "String selection detected.\n " + f"Accepted strings: {string_selection}\n " + "All other strings are ignored!" + ) + ) + if isinstance(string_selection, int): + string_selection = [string_selection] + + self._string_selection = string_selection + + self._selection = None + if self._string_selection: + self._selection = f"string in {str(tuple(self._string_selection))}" + + self._loss_weight_column = loss_weight_column + self._loss_weight_table = loss_weight_table + if (self._loss_weight_table is None) and ( + self._loss_weight_column is not None + ): + self.warning("Error: no loss weight table specified") + assert isinstance(self._loss_weight_table, str) + if (self._loss_weight_table is not None) and ( + self._loss_weight_column is None + ): + self.warning("Error: no loss weight column specified") + assert isinstance(self._loss_weight_column, str) + + self._dtype = dtype + + self._label_fns: Dict[str, Callable[[Data], Any]] = {} + + self._string_selection_resolver = StringSelectionResolver( + self, + index_column=index_column, + seed=seed, + ) + + # Implementation-specific initialisation. + self._init() + + # Set unique indices + self._indices: Union[List[int], List[List[int]]] + if selection is None: + self._indices = self._get_all_indices() + elif isinstance(selection, str): + self._indices = self._resolve_string_selection_to_indices( + selection + ) + else: + self._indices = selection + + # Purely internal member variables + self._missing_variables: Dict[str, List[str]] = {} + self._remove_missing_columns() + + # Implementation-specific post-init code. + self._post_init() + + # Properties + @property + def path(self) -> Union[str, List[str]]: + """Path to the file(s) from which this `Dataset` reads.""" + return self._path + + @property + def truth_table(self) -> str: + """Name of the table containing event-level truth information.""" + return self._truth_table + + # Abstract method(s) + @abstractmethod + def _init(self) -> None: + """Set internal representation needed to read data from input file.""" + + def _post_init(self) -> None: + """Implementation-specific code executed after the main constructor.""" + + @abstractmethod + def _get_all_indices(self) -> List[int]: + """Return a list of all available values in `self._index_column`.""" + + @abstractmethod + def _get_event_index( + self, sequential_index: Optional[int] + ) -> Optional[int]: + """Return the event index corresponding to a `sequential_index`.""" + +
+[docs] + @abstractmethod + def query_table( + self, + table: str, + columns: Union[List[str], str], + sequential_index: Optional[int] = None, + selection: Optional[str] = None, + ) -> List[Tuple[Any, ...]]: + """Query a table at a specific index, optionally with some selection. + + Args: + table: Table to be queried. + columns: Columns to read out. + sequential_index: Sequentially numbered index + (i.e. in [0,len(self))) of the event to query. This _may_ + differ from the indexation used in `self._indices`. If no value + is provided, the entire column is returned. + selection: Selection to be imposed before reading out data. + Defaults to None. + + Returns: + List of tuples containing the values in `columns`. If the `table` + contains only scalar data for `columns`, a list of length 1 is + returned + + Raises: + ColumnMissingException: If one or more element in `columns` is not + present in `table`. + """
+ + + # Public method(s) +
+[docs] + def add_label( + self, fn: Callable[[Data], Any], key: Optional[str] = None + ) -> None: + """Add custom graph label define using function `fn`.""" + if isinstance(fn, Label): + key = fn.key + assert isinstance( + key, str + ), "Please specify a key for the custom label to be added." + assert ( + key not in self._label_fns + ), f"A custom label {key} has already been defined." + self._label_fns[key] = fn
+ + + def __len__(self) -> int: + """Return number of graphs in `Dataset`.""" + return len(self._indices) + + def __getitem__(self, sequential_index: int) -> Data: + """Return graph `Data` object at `index`.""" + if not (0 <= sequential_index < len(self)): + raise IndexError( + f"Index {sequential_index} not in range [0, {len(self) - 1}]" + ) + features, truth, node_truth, loss_weight = self._query( + sequential_index + ) + graph = self._create_graph(features, truth, node_truth, loss_weight) + return graph + + # Internal method(s) + def _resolve_string_selection_to_indices( + self, selection: str + ) -> List[int]: + """Resolve selection as string to list of indices. + + Selections are expected to have pandas.DataFrame.query-compatible + syntax, e.g., ``` "event_no % 5 > 0" ``` Selections may also specify a + fixed number of events to randomly sample, e.g., ``` "10000 random + events ~ event_no % 5 > 0" "20% random events ~ event_no % 5 > 0" ``` + """ + return self._string_selection_resolver.resolve(selection) + + def _remove_missing_columns(self) -> None: + """Remove columns that are not present in the input file. + + Columns are removed from `self._features` and `self._truth`. + """ + # Check if table is completely empty + if len(self) == 0: + self.warning("Dataset is empty.") + return + + # Find missing features + missing_features_set = set(self._features) + for pulsemap in self._pulsemaps: + missing = self._check_missing_columns(self._features, pulsemap) + missing_features_set = missing_features_set.intersection(missing) + + missing_features = list(missing_features_set) + + # Find missing truth variables + missing_truth_variables = self._check_missing_columns( + self._truth, self._truth_table + ) + + # Remove missing features + if missing_features: + self.warning( + "Removing the following (missing) features: " + + ", ".join(missing_features) + ) + for missing_feature in missing_features: + self._features.remove(missing_feature) + + # Remove missing truth variables + if missing_truth_variables: + self.warning( + ( + "Removing the following (missing) truth variables: " + + ", ".join(missing_truth_variables) + ) + ) + for missing_truth_variable in missing_truth_variables: + self._truth.remove(missing_truth_variable) + + def _check_missing_columns( + self, + columns: List[str], + table: str, + ) -> List[str]: + """Return a list missing columns in `table`.""" + for column in columns: + try: + self.query_table(table, [column], 0) + except ColumnMissingException: + if table not in self._missing_variables: + self._missing_variables[table] = [] + self._missing_variables[table].append(column) + except IndexError: + self.warning(f"Dataset contains no entries for {column}") + + return self._missing_variables.get(table, []) + + def _query( + self, sequential_index: int + ) -> Tuple[ + List[Tuple[float, ...]], + Tuple[Any, ...], + Optional[List[Tuple[Any, ...]]], + Optional[float], + ]: + """Query file for event features and truth information. + + The returned lists have lengths corresponding to the number of pulses + in the event. Their constituent tuples have lengths corresponding to + the number of features/attributes in each output + + Args: + sequential_index: Sequentially numbered index + (i.e. in [0,len(self))) of the event to query. This _may_ + differ from the indexation used in `self._indices`. + + Returns: + Tuple containing pulse-level event features; event-level truth + information; pulse-level truth information; and event-level + loss weights, respectively. + """ + features = [] + for pulsemap in self._pulsemaps: + features_pulsemap = self.query_table( + pulsemap, self._features, sequential_index, self._selection + ) + features.extend(features_pulsemap) + + truth: Tuple[Any, ...] = self.query_table( + self._truth_table, self._truth, sequential_index + )[0] + if self._node_truth: + assert self._node_truth_table is not None + node_truth = self.query_table( + self._node_truth_table, + self._node_truth, + sequential_index, + self._selection, + ) + else: + node_truth = None + + loss_weight: Optional[float] = None # Default + if self._loss_weight_column is not None: + assert self._loss_weight_table is not None + loss_weight_list = self.query_table( + self._loss_weight_table, + self._loss_weight_column, + sequential_index, + ) + if len(loss_weight_list): + loss_weight = loss_weight_list[0][0] + else: + loss_weight = -1.0 + + return features, truth, node_truth, loss_weight + + def _create_graph( + self, + features: List[Tuple[float, ...]], + truth: Tuple[Any, ...], + node_truth: Optional[List[Tuple[Any, ...]]] = None, + loss_weight: Optional[float] = None, + ) -> Data: + """Create Pytorch Data (i.e. graph) object. + + Args: + features: List of tuples, containing event features. + truth: List of tuples, containing truth information. + node_truth: List of tuples, containing node-level truth. + loss_weight: A weight associated with the event for weighing the + loss. + + Returns: + Graph object. + """ + # Convert nested list to simple dict + truth_dict = { + key: truth[index] for index, key in enumerate(self._truth) + } + + # Define custom labels + labels_dict = self._get_labels(truth_dict) + + # Convert nested list to simple dict + if node_truth is not None: + node_truth_array = np.asarray(node_truth) + assert self._node_truth is not None + node_truth_dict = { + key: node_truth_array[:, index] + for index, key in enumerate(self._node_truth) + } + + # Create list of truth dicts with labels + truth_dicts = [labels_dict, truth_dict] + if node_truth is not None: + truth_dicts.append(node_truth_dict) + + # Catch cases with no reconstructed pulses + if len(features): + node_features = np.asarray(features)[ + :, 1: + ] # first entry is index column + else: + node_features = np.array([]).reshape((0, len(self._features) - 1)) + + # Construct graph data object + assert self._graph_definition is not None + graph = self._graph_definition( + node_features=node_features, + node_feature_names=self._features[ + 1: + ], # first entry is index column + truth_dicts=truth_dicts, + custom_label_functions=self._label_fns, + loss_weight_column=self._loss_weight_column, + loss_weight=loss_weight, + loss_weight_default_value=self._loss_weight_default_value, + data_path=self._path, + ) + return graph + + def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]: + """Return dictionary of labels, to be added as graph attributes.""" + if "pid" in truth_dict.keys(): + abs_pid = abs(truth_dict["pid"]) + sim_type = truth_dict["sim_type"] + + labels_dict = { + self._index_column: truth_dict[self._index_column], + "muon": int(abs_pid == 13), + "muon_stopped": int(truth_dict.get("stopped_muon") == 1), + "noise": int((abs_pid == 1) & (sim_type != "data")), + "neutrino": int( + (abs_pid != 13) & (abs_pid != 1) + ), # @TODO: `abs_pid in [12,14,16]`? + "v_e": int(abs_pid == 12), + "v_u": int(abs_pid == 14), + "v_t": int(abs_pid == 16), + "track": int( + (abs_pid == 14) & (truth_dict["interaction_type"] == 1) + ), + "dbang": self._get_dbang_label(truth_dict), + "corsika": int(abs_pid > 20), + } + else: + labels_dict = { + self._index_column: truth_dict[self._index_column], + "muon": -1, + "muon_stopped": -1, + "noise": -1, + "neutrino": -1, + "v_e": -1, + "v_u": -1, + "v_t": -1, + "track": -1, + "dbang": -1, + "corsika": -1, + } + return labels_dict + + def _get_dbang_label(self, truth_dict: Dict[str, Any]) -> int: + """Get label for double-bang classification.""" + try: + label = int(truth_dict["dbang_decay_length"] > -1) + return label + except KeyError: + return -1
+ + + +
+[docs] +class EnsembleDataset(torch.utils.data.ConcatDataset): + """Construct a single dataset from a collection of datasets.""" + + def __init__(self, datasets: Iterable[Dataset]) -> None: + """Construct a single dataset from a collection of datasets. + + Args: + datasets: A collection of Datasets + """ + super().__init__(datasets=datasets)
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/parquet/parquet_dataset.html b/_modules/graphnet/data/dataset/parquet/parquet_dataset.html new file mode 100644 index 000000000..35d6b34d7 --- /dev/null +++ b/_modules/graphnet/data/dataset/parquet/parquet_dataset.html @@ -0,0 +1,500 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.data.dataset.parquet.parquet_dataset — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.data.dataset.parquet.parquet_dataset

+"""`Dataset` class(es) for reading from Parquet files."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+import numpy as np
+import awkward as ak
+
+from graphnet.data.dataset.dataset import Dataset, ColumnMissingException
+
+
+
+[docs] +class ParquetDataset(Dataset): + """Pytorch dataset for reading from Parquet files.""" + + # Implementing abstract method(s) + def _init(self) -> None: + # Check(s) + if not isinstance(self._path, list): + + assert isinstance(self._path, str) + + assert self._path.endswith( + ".parquet" + ), f"Format of input file `{self._path}` is not supported" + + assert ( + self._node_truth is None + ), "Argument `node_truth` is currently not supported." + assert ( + self._node_truth_table is None + ), "Argument `node_truth_table` is currently not supported." + assert ( + self._string_selection is None + ), "Argument `string_selection` is currently not supported" + + # Set custom member variable(s) + if not isinstance(self._path, list): + self._parquet_hook = ak.from_parquet(self._path, lazy=False) + else: + self._parquet_hook = ak.concatenate( + ak.from_parquet(file) for file in self._path + ) + + def _get_all_indices(self) -> List[int]: + return np.arange( + len( + ak.to_numpy( + self._parquet_hook[self._truth_table][self._index_column] + ).tolist() + ) + ).tolist() + + def _get_event_index( + self, sequential_index: Optional[int] + ) -> Optional[int]: + index: Optional[int] + if sequential_index is None: + index = None + else: + index = cast(List[int], self._indices)[sequential_index] + + return index + + def _format_dictionary_result( + self, dictionary: Dict + ) -> List[Tuple[Any, ...]]: + """Convert the output of `ak.to_list()` into a list of tuples.""" + # All scalar values + if all(map(np.isscalar, dictionary.values())): + return [tuple(dictionary.values())] + + # All arrays should have same length + array_lengths = [ + len(values) + for values in dictionary.values() + if not np.isscalar(values) + ] + assert len(set(array_lengths)) == 1, ( + f"Arrays in {dictionary} have differing lengths " + f"({set(array_lengths)})." + ) + nb_elements = array_lengths[0] + + # Broadcast scalars + for key in dictionary: + value = dictionary[key] + if np.isscalar(value): + dictionary[key] = np.repeat( + value, repeats=nb_elements + ).tolist() + + return list(map(tuple, list(zip(*dictionary.values())))) + +
+[docs] + def query_table( + self, + table: str, + columns: Union[List[str], str], + sequential_index: Optional[int] = None, + selection: Optional[str] = None, + ) -> List[Tuple[Any, ...]]: + """Query table at a specific index, optionally with some selection.""" + # Check(s) + assert ( + selection is None + ), "Argument `selection` is currently not supported" + + index = self._get_event_index(sequential_index) + + try: + if index is None: + ak_array = self._parquet_hook[table][columns][:] + else: + ak_array = self._parquet_hook[table][columns][index] + except ValueError as e: + if "does not exist (not in record)" in str(e): + raise ColumnMissingException(str(e)) + else: + raise e + + output = ak_array.to_list() + + result: List[Tuple[Any, ...]] = [] + + # Querying single index + if isinstance(output, dict): + assert list(output.keys()) == columns + result = self._format_dictionary_result(output) + + # Querying entire columm + elif isinstance(output, list): + for dictionary in output: + assert list(dictionary.keys()) == columns + result.extend(self._format_dictionary_result(dictionary)) + + return result
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html new file mode 100644 index 000000000..a6734e1c8 --- /dev/null +++ b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset.html @@ -0,0 +1,515 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.data.dataset.sqlite.sqlite_dataset — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.data.dataset.sqlite.sqlite_dataset

+"""`Dataset` class(es) for reading data from SQLite databases."""
+
+from typing import Any, List, Optional, Tuple, Union
+import pandas as pd
+import sqlite3
+
+from graphnet.data.dataset.dataset import Dataset, ColumnMissingException
+
+
+
+[docs] +class SQLiteDataset(Dataset): + """Pytorch dataset for reading data from SQLite databases.""" + + # Implementing abstract method(s) + def _init(self) -> None: + # Check(s) + self._database_list: Optional[List[str]] + if isinstance(self._path, list): + self._database_list = self._path + self._all_connections_established = False + self._all_connections: List[sqlite3.Connection] = [] + else: + self._database_list = None + assert isinstance(self._path, str) + assert self._path.endswith( + ".db" + ), f"Format of input file `{self._path}` is not supported." + + if self._database_list is not None: + self._current_database: Optional[int] = None + + # Set custom member variable(s) + self._features_string = ", ".join(self._features) + self._truth_string = ", ".join(self._truth) + if self._node_truth: + self._node_truth_string = ", ".join(self._node_truth) + + self._conn: Optional[sqlite3.Connection] = None + + def _post_init(self) -> None: + self._close_connection() + +
+[docs] + def query_table( + self, + table: str, + columns: Union[List[str], str], + sequential_index: Optional[int] = None, + selection: Optional[str] = None, + ) -> List[Tuple[Any, ...]]: + """Query table at a specific index, optionally with some selection.""" + # Check(s) + if isinstance(columns, list): + columns = ", ".join(columns) + + if not selection: # I.e., `None` or `""` + selection = "1=1" # Identically true, to select all + + index = self._get_event_index(sequential_index) + + # Query table + assert index is not None + self._establish_connection(index) + try: + assert self._conn + if sequential_index is None: + combined_selections = selection + else: + combined_selections = ( + f"{self._index_column} = {index} and {selection}" + ) + + result = self._conn.execute( + f"SELECT {columns} FROM {table} WHERE " + f"{combined_selections}" + ).fetchall() + except sqlite3.OperationalError as e: + if "no such column" in str(e): + raise ColumnMissingException(str(e)) + else: + raise e + return result
+ + + def _get_all_indices(self) -> List[int]: + self._establish_connection(0) + indices = pd.read_sql_query( + f"SELECT {self._index_column} FROM {self._truth_table}", self._conn + ) + self._close_connection() + return indices.values.ravel().tolist() + + def _get_event_index( + self, sequential_index: Optional[int] + ) -> Optional[int]: + index: int = 0 + if sequential_index is not None: + index_ = self._indices[sequential_index] + if self._database_list is None: + assert isinstance(index_, int) + index = index_ + else: + assert isinstance(index_, list) + index = index_[0] + return index + + # Custom, internal method(s) + # @TODO: Is it necessary to return anything here? + def _establish_connection(self, i: int) -> "SQLiteDataset": + """Make sure that a sqlite3 connection is open.""" + if self._database_list is None: + assert isinstance(self._path, str) + if self._conn is None: + self._conn = sqlite3.connect(self._path) + else: + indices = self._indices[i] + assert isinstance(indices, list) + if self._conn is None: + if self._all_connections_established is False: + self._all_connections = [] + for database in self._database_list: + con = sqlite3.connect(database) + self._all_connections.append(con) + self._all_connections_established = True + self._conn = self._all_connections[indices[1]] + if indices[1] != self._current_database: + self._conn = self._all_connections[indices[1]] + self._current_database = indices[1] + return self + + # @TODO: Is it necessary to return anything here? + def _close_connection(self) -> "SQLiteDataset": + """Make sure that no sqlite3 connection is open. + + This is necessary to calls this before passing to + `torch.DataLoader` such that the dataset replica on each worker + is required to create its own connection (thereby avoiding + `sqlite3.DatabaseError: database disk image is malformed` errors + due to inability to use sqlite3 connection accross processes. + """ + if self._conn is not None: + self._conn.close() + del self._conn + self._conn = None + if self._database_list is not None: + if self._all_connections_established: + for con in self._all_connections: + con.close() + del self._all_connections + self._all_connections_established = False + self._conn = None + return self
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html new file mode 100644 index 000000000..44acbb9c5 --- /dev/null +++ b/_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.html @@ -0,0 +1,515 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.data.dataset.sqlite.sqlite_dataset_perturbed — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.data.dataset.sqlite.sqlite_dataset_perturbed

+"""`Dataset` class(es) for reading perturbed data from SQLite databases."""
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+from numpy.random import default_rng, Generator
+import torch
+from torch_geometric.data import Data
+
+from .sqlite_dataset import SQLiteDataset
+
+
+
+[docs] +class SQLiteDatasetPerturbed(SQLiteDataset): + """Pytorch dataset for reading perturbed data from SQLite databases. + + This including a pre-processing step, where the input data is randomly + perturbed according to given per-feature "noise" levels. This is intended + to test the stability of a trained model under small changes to the input + parameters. + """ + + def __init__( + self, + path: Union[str, List[str]], + pulsemaps: Union[str, List[str]], + features: List[str], + truth: List[str], + *, + perturbation_dict: Dict[str, float], + node_truth: Optional[List[str]] = None, + index_column: str = "event_no", + truth_table: str = "truth", + node_truth_table: Optional[str] = None, + string_selection: Optional[List[int]] = None, + selection: Optional[List[int]] = None, + dtype: torch.dtype = torch.float32, + loss_weight_table: Optional[str] = None, + loss_weight_column: Optional[str] = None, + loss_weight_default_value: Optional[float] = None, + seed: Optional[Union[int, Generator]] = None, + ): + """Construct SQLiteDatasetPerturbed. + + Args: + path: Path to the file(s) from which this `Dataset` should read. + pulsemaps: Name(s) of the pulse map series that should be used to + construct the nodes on the individual graph objects, and their + features. Multiple pulse series maps can be used, e.g., when + different DOM types are stored in different maps. + features: List of columns in the input files that should be used as + node features on the graph objects. + truth: List of event-level columns in the input files that should + be used added as attributes on the graph objects. + perturbation_dict (Dict[str, float]): Dictionary mapping a feature + name to a standard deviation according to which the values for + this feature should be randomly perturbed. + node_truth: List of node-level columns in the input files that + should be used added as attributes on the graph objects. + index_column: Name of the column in the input files that contains + unique indicies to identify and map events across tables. + truth_table: Name of the table containing event-level truth + information. + node_truth_table: Name of the table containing node-level truth + information. + string_selection: Subset of strings for which data should be read + and used to construct graph objects. Defaults to None, meaning + all strings for which data exists are used. + selection: List of indicies (in `index_column`) of the events in + the input files that should be read. Defaults to None, meaning + that all events in the input files are read. + dtype: Type of the feature tensor on the graph objects returned. + loss_weight_table: Name of the table containing per-event loss + weights. + loss_weight_column: Name of the column in `loss_weight_table` + containing per-event loss weights. This is also the name of the + corresponding attribute assigned to the graph object. + loss_weight_default_value: Default per-event loss weight. + NOTE: This default value is only applied when + `loss_weight_table` and `loss_weight_column` are specified, and + in this case to events with no value in the corresponding + table/column. That is, if no per-event loss weight table/column + is provided, this value is ignored. Defaults to None. + seed: Optional seed for random number generation. Defaults to None. + """ + # Base class constructor + super().__init__( + path=path, + pulsemaps=pulsemaps, + features=features, + truth=truth, + node_truth=node_truth, + index_column=index_column, + truth_table=truth_table, + node_truth_table=node_truth_table, + string_selection=string_selection, + selection=selection, + dtype=dtype, + loss_weight_table=loss_weight_table, + loss_weight_column=loss_weight_column, + loss_weight_default_value=loss_weight_default_value, + ) + + # Custom member variables + assert isinstance(perturbation_dict, dict) + assert len(set(perturbation_dict.keys())) == len( + perturbation_dict.keys() + ) + self._perturbation_dict = perturbation_dict + + self._perturbation_cols = [ + self._features.index(key) for key in self._perturbation_dict.keys() + ] + + if seed is not None: + if isinstance(seed, int): + self.rng = default_rng(seed) + elif isinstance(seed, Generator): + self.rng = seed + else: + raise ValueError( + "Invalid seed. Must be an int or a numpy Generator." + ) + else: + self.rng = default_rng() + + def __getitem__(self, sequential_index: int) -> Data: + """Return graph `Data` object at `index`.""" + if not (0 <= sequential_index < len(self)): + raise IndexError( + f"Index {sequential_index} not in range [0, {len(self) - 1}]" + ) + features, truth, node_truth, loss_weight = self._query( + sequential_index + ) + perturbed_features = self._perturb_features(features) + graph = self._create_graph( + perturbed_features, truth, node_truth, loss_weight + ) + return graph + + def _perturb_features( + self, features: List[Tuple[float, ...]] + ) -> List[Tuple[float, ...]]: + features_array = np.array(features) + perturbed_features = self.rng.normal( + loc=features_array[:, self._perturbation_cols], + scale=np.array( + list(self._perturbation_dict.values()), dtype=np.float + ), + ) + features_array[:, self._perturbation_cols] = perturbed_features + return features_array.tolist()
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/data/extractors/i3extractor.html b/_modules/graphnet/data/extractors/i3extractor.html index cae615098..0680557f4 100644 --- a/_modules/graphnet/data/extractors/i3extractor.html +++ b/_modules/graphnet/data/extractors/i3extractor.html @@ -464,7 +464,7 @@

Source code for Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3featureextractor.html b/_modules/graphnet/data/extractors/i3featureextractor.html index 06d9d3e68..de384a97e 100644 --- a/_modules/graphnet/data/extractors/i3featureextractor.html +++ b/_modules/graphnet/data/extractors/i3featureextractor.html @@ -647,7 +647,7 @@

Source c Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3genericextractor.html b/_modules/graphnet/data/extractors/i3genericextractor.html index 0324f9087..9c7d786f3 100644 --- a/_modules/graphnet/data/extractors/i3genericextractor.html +++ b/_modules/graphnet/data/extractors/i3genericextractor.html @@ -636,7 +636,7 @@

Source c Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3hybridrecoextractor.html b/_modules/graphnet/data/extractors/i3hybridrecoextractor.html index 21824f54b..ee80f730f 100644 --- a/_modules/graphnet/data/extractors/i3hybridrecoextractor.html +++ b/_modules/graphnet/data/extractors/i3hybridrecoextractor.html @@ -400,7 +400,7 @@

Sourc Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html b/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html index 2a3a45023..d295f81a2 100644 --- a/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html +++ b/_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.html @@ -407,7 +407,7 @@

Sou Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3particleextractor.html b/_modules/graphnet/data/extractors/i3particleextractor.html index 8d17ef0ee..60463b59a 100644 --- a/_modules/graphnet/data/extractors/i3particleextractor.html +++ b/_modules/graphnet/data/extractors/i3particleextractor.html @@ -392,7 +392,7 @@

Source Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3pisaextractor.html b/_modules/graphnet/data/extractors/i3pisaextractor.html index 635ba4537..0752dc39a 100644 --- a/_modules/graphnet/data/extractors/i3pisaextractor.html +++ b/_modules/graphnet/data/extractors/i3pisaextractor.html @@ -385,7 +385,7 @@

Source code Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3quesoextractor.html b/_modules/graphnet/data/extractors/i3quesoextractor.html index bcac49e76..6ad9572a6 100644 --- a/_modules/graphnet/data/extractors/i3quesoextractor.html +++ b/_modules/graphnet/data/extractors/i3quesoextractor.html @@ -395,7 +395,7 @@

Source cod Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3retroextractor.html b/_modules/graphnet/data/extractors/i3retroextractor.html index dea96d47a..b76bed436 100644 --- a/_modules/graphnet/data/extractors/i3retroextractor.html +++ b/_modules/graphnet/data/extractors/i3retroextractor.html @@ -467,7 +467,7 @@

Source cod Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3splinempeextractor.html b/_modules/graphnet/data/extractors/i3splinempeextractor.html index 79f5f9457..dd8d383bd 100644 --- a/_modules/graphnet/data/extractors/i3splinempeextractor.html +++ b/_modules/graphnet/data/extractors/i3splinempeextractor.html @@ -379,7 +379,7 @@

Source Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3truthextractor.html b/_modules/graphnet/data/extractors/i3truthextractor.html index 330ad508b..19d005d0d 100644 --- a/_modules/graphnet/data/extractors/i3truthextractor.html +++ b/_modules/graphnet/data/extractors/i3truthextractor.html @@ -781,7 +781,7 @@

Source cod Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/i3tumextractor.html b/_modules/graphnet/data/extractors/i3tumextractor.html index d22868f07..89adb77c2 100644 --- a/_modules/graphnet/data/extractors/i3tumextractor.html +++ b/_modules/graphnet/data/extractors/i3tumextractor.html @@ -382,7 +382,7 @@

Source code Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/utilities/collections.html b/_modules/graphnet/data/extractors/utilities/collections.html index 11ccbd8f5..14835acc6 100644 --- a/_modules/graphnet/data/extractors/utilities/collections.html +++ b/_modules/graphnet/data/extractors/utilities/collections.html @@ -436,7 +436,7 @@

Sourc Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/utilities/frames.html b/_modules/graphnet/data/extractors/utilities/frames.html index 4dfe93249..656cb6033 100644 --- a/_modules/graphnet/data/extractors/utilities/frames.html +++ b/_modules/graphnet/data/extractors/utilities/frames.html @@ -439,7 +439,7 @@

Source cod Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/extractors/utilities/types.html b/_modules/graphnet/data/extractors/utilities/types.html index f2d4ecd65..60de6d341 100644 --- a/_modules/graphnet/data/extractors/utilities/types.html +++ b/_modules/graphnet/data/extractors/utilities/types.html @@ -660,7 +660,7 @@

Source code Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/parquet/parquet_dataconverter.html b/_modules/graphnet/data/parquet/parquet_dataconverter.html index 7cf5c3507..455153448 100644 --- a/_modules/graphnet/data/parquet/parquet_dataconverter.html +++ b/_modules/graphnet/data/parquet/parquet_dataconverter.html @@ -407,7 +407,7 @@

Source c Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/pipeline.html b/_modules/graphnet/data/pipeline.html new file mode 100644 index 000000000..90602e63c --- /dev/null +++ b/_modules/graphnet/data/pipeline.html @@ -0,0 +1,593 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.data.pipeline — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.data.pipeline

+"""Class(es) used for analysis in PISA."""
+
+from abc import ABC
+import dill
+from functools import reduce
+import os
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+from pytorch_lightning import Trainer
+import sqlite3
+import torch
+from torch.utils.data import DataLoader
+
+from graphnet.data.sqlite.sqlite_utilities import create_table_and_save_to_sql
+from graphnet.training.utils import get_predictions, make_dataloader
+from graphnet.models.graphs import GraphDefinition
+
+from graphnet.utilities.logging import Logger
+
+
+
+[docs] +class InSQLitePipeline(ABC, Logger): + """Create a SQLite database for PISA analysis. + + The database will contain truth and GNN predictions and, if available, + RETRO reconstructions. + """ + + def __init__( + self, + module_dict: Dict, + features: List[str], + truth: List[str], + device: torch.device, + retro_table_name: str = "retro", + outdir: Optional[str] = None, + batch_size: int = 100, + n_workers: int = 10, + pipeline_name: str = "pipeline", + ): + """Initialise the pipeline. + + Args: + module_dict: A dictionary with GNN modules from GraphNet. E.g. + {'energy': gnn_module_for_energy_regression} + features: List of input features for the GNN modules. + truth: List of truth for the GNN ModuleList. + device: The device used for computation. + retro_table_name: Name of the retro table for. + outdir: the directory in which the pipeline database will be + stored. + batch_size: Batch size for inference. + n_workers: Number of workers used in dataloading. + pipeline_name: Name of the pipeline. If such a pipeline already + exists, an error will be prompted to avoid overwriting. + """ + self._pipeline_name = pipeline_name + self._device = device + self.n_workers = n_workers + self._features = features + self._truth = truth + self._batch_size = batch_size + self._outdir = outdir + self._module_dict = module_dict + self._retro_table_name = retro_table_name + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + def __call__( + self, + database: str, + pulsemap: str, + graph_definition: GraphDefinition, + chunk_size: int = 1000000, + ) -> None: + """Run inference of each field in self._module_dict[target]['']. + + Args: + database: Path to database with pulsemap and truth. + pulsemap: Name of pulsemaps. + graph_definition: GraphDefinition for Dataset + chunk_size: database will be sliced in chunks of size `chunk_size`. + Use this parameter to control memory usage. + """ + outdir = self._get_outdir(database) + if isinstance( + self._device, str + ): # Because pytorch lightning insists on breaking pytorch cuda device naming scheme + device = int(self._device[-1]) + if not os.path.isdir(outdir): + dataloaders, event_batches = self._setup_dataloaders( + graph_definition=graph_definition, + chunk_size=chunk_size, + db=database, + pulsemap=pulsemap, + selection=None, + persistent_workers=False, + ) + i = 0 + for dataloader in dataloaders: + self.info("CHUNK %s / %s" % (i, len(dataloaders))) + df = self._inference(device, dataloader) + truth = self._get_truth(database, event_batches[i].tolist()) + retro = self._get_retro(database, event_batches[i].tolist()) + self._append_to_pipeline(outdir, truth, retro, df) + i += 1 + else: + self.info(outdir) + self.info( + "WARNING - Pipeline named %s already exists! \n Please rename pipeline!" + % self._pipeline_name + ) + + def _setup_dataloaders( + self, + chunk_size: int, + db: str, + pulsemap: str, + graph_definition: GraphDefinition, + selection: Optional[List[int]] = None, + persistent_workers: bool = False, + ) -> Tuple[List[DataLoader], List[np.ndarray]]: + if selection is None: + selection = self._get_all_event_nos(db) + n_chunks = np.ceil(len(selection) / chunk_size) + event_batches = np.array_split(selection, n_chunks) + dataloaders = [] + for batch in event_batches: + dataloaders.append( + make_dataloader( + db=db, + graph_definition=graph_definition, + pulsemaps=pulsemap, + features=self._features, + truth=self._truth, + batch_size=self._batch_size, + shuffle=False, + selection=batch.tolist(), + num_workers=self.n_workers, + persistent_workers=persistent_workers, + ) + ) + return dataloaders, event_batches + + def _get_all_event_nos(self, db: str) -> List[int]: + with sqlite3.connect(db) as con: + query = "SELECT event_no FROM truth" + selection = pd.read_sql(query, con).values.ravel().tolist() + return selection + + def _combine_outputs(self, dataframes: List[pd.DataFrame]) -> pd.DataFrame: + return reduce(lambda x, y: pd.merge(x, y, on="event_no"), dataframes) + + def _inference( + self, device: torch.device, dataloader: DataLoader + ) -> pd.DataFrame: + dataframes = [] + for target in self._module_dict.keys(): + # dataloader = iter(dataloader) + trainer = Trainer(devices=[device], accelerator="gpu") + model = torch.load( + self._module_dict[target]["path"], + map_location="cpu", + pickle_module=dill, + ) + model.eval() + model.inference() + results = get_predictions( + trainer, + model, + dataloader, + self._module_dict[target]["output_column_names"], + additional_attributes=["event_no"], + ) + dataframes.append( + results.sort_values("event_no").reset_index(drop=True) + ) + df = self._combine_outputs(dataframes) + return df + + def _get_outdir(self, database: str) -> str: + if self._outdir is None: + database_name = database.split("/")[-3] + outdir = ( + database.split(database_name)[0] + + database_name + + "/pipelines/" + + self._pipeline_name + ) + else: + outdir = self._outdir + return outdir + + def _get_truth(self, database: str, selection: List[int]) -> pd.DataFrame: + with sqlite3.connect(database) as con: + query = "SELECT * FROM truth WHERE event_no in %s" % str( + tuple(selection) + ) + truth = pd.read_sql(query, con) + return truth + + def _get_retro(self, database: str, selection: List[int]) -> pd.DataFrame: + try: + with sqlite3.connect(database) as con: + query = "SELECT * FROM %s WHERE event_no in %s" % ( + self._retro_table_name, + str(tuple(selection)), + ) + retro = pd.read_sql(query, con) + return retro + except: # noqa: E722 + self.info("%s table does not exist" % self._retro_table_name) + + def _append_to_pipeline( + self, + outdir: str, + truth: pd.DataFrame, + retro: pd.DataFrame, + df: pd.DataFrame, + ) -> None: + os.makedirs(outdir, exist_ok=True) + pipeline_database = outdir + "/%s.db" % self._pipeline_name + create_table_and_save_to_sql(df, "reconstruction", pipeline_database) + create_table_and_save_to_sql(truth, "truth", pipeline_database) + if isinstance(retro, pd.DataFrame): + create_table_and_save_to_sql( + retro, self._retro_table_name, pipeline_database + )
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/data/sqlite/sqlite_dataconverter.html b/_modules/graphnet/data/sqlite/sqlite_dataconverter.html index a067a451a..69fa21bbe 100644 --- a/_modules/graphnet/data/sqlite/sqlite_dataconverter.html +++ b/_modules/graphnet/data/sqlite/sqlite_dataconverter.html @@ -716,7 +716,7 @@

Source cod Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/sqlite/sqlite_utilities.html b/_modules/graphnet/data/sqlite/sqlite_utilities.html index 29d4e537d..a910e007d 100644 --- a/_modules/graphnet/data/sqlite/sqlite_utilities.html +++ b/_modules/graphnet/data/sqlite/sqlite_utilities.html @@ -515,7 +515,7 @@

Source code fo Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/utilities/parquet_to_sqlite.html b/_modules/graphnet/data/utilities/parquet_to_sqlite.html index e645a35f4..f3944caa0 100644 --- a/_modules/graphnet/data/utilities/parquet_to_sqlite.html +++ b/_modules/graphnet/data/utilities/parquet_to_sqlite.html @@ -533,7 +533,7 @@

Source cod Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/utilities/random.html b/_modules/graphnet/data/utilities/random.html index f297fc67c..eb3a7f85c 100644 --- a/_modules/graphnet/data/utilities/random.html +++ b/_modules/graphnet/data/utilities/random.html @@ -375,7 +375,7 @@

Source code for graph Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/data/utilities/string_selection_resolver.html b/_modules/graphnet/data/utilities/string_selection_resolver.html index bd7909f33..361902f54 100644 --- a/_modules/graphnet/data/utilities/string_selection_resolver.html +++ b/_modules/graphnet/data/utilities/string_selection_resolver.html @@ -676,7 +676,7 @@

So Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/deployment/i3modules/graphnet_module.html b/_modules/graphnet/deployment/i3modules/graphnet_module.html new file mode 100644 index 000000000..994461c04 --- /dev/null +++ b/_modules/graphnet/deployment/i3modules/graphnet_module.html @@ -0,0 +1,817 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.deployment.i3modules.graphnet_module — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.deployment.i3modules.graphnet_module

+"""Class(es) for deploying GraphNeT models in icetray as I3Modules."""
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Any, List, Union, Dict, Tuple, Optional
+
+import dill
+import numpy as np
+import torch
+from torch_geometric.data import Data, Batch
+
+from graphnet.data.extractors import (
+    I3FeatureExtractor,
+    I3FeatureExtractorIceCubeUpgrade,
+)
+from graphnet.models import Model, StandardModel
+from graphnet.models.graphs import GraphDefinition
+from graphnet.utilities.imports import has_icecube_package
+from graphnet.utilities.config import ModelConfig
+
+if has_icecube_package() or TYPE_CHECKING:
+    from icecube.icetray import (
+        I3Module,
+        I3Frame,
+    )  # pyright: reportMissingImports=false
+    from icecube.dataclasses import (
+        I3Double,
+        I3MapKeyVectorDouble,
+    )  # pyright: reportMissingImports=false
+    from icecube import dataclasses, dataio, icetray
+
+
+
+[docs] +class GraphNeTI3Module: + """Base I3 Module for GraphNeT. + + Contains methods for extracting pulsemaps, producing graphs and writing to + frames. + """ + + def __init__( + self, + graph_definition: GraphDefinition, + pulsemap: str, + features: List[str], + pulsemap_extractor: Union[ + List[I3FeatureExtractor], I3FeatureExtractor + ], + gcd_file: str, + ): + """I3Module Constructor. + + Arguments: + graph_definition: An instance of GraphDefinition. E.g. KNNGraph. + pulsemap: the pulse map on which the module functions + features: the features that is used from the pulse map. + E.g. [dom_x, dom_y, dom_z, charge] + pulsemap_extractor: The I3FeatureExtractor used to extract the + pulsemap from the I3Frames + gcd_file: Path to the associated gcd-file. + """ + assert isinstance(graph_definition, GraphDefinition) + self._graph_definition = graph_definition + self._pulsemap = pulsemap + self._features = features + assert isinstance(gcd_file, str), "gcd_file must be string" + self._gcd_file = gcd_file + if isinstance(pulsemap_extractor, list): + self._i3_extractors = pulsemap_extractor + else: + self._i3_extractors = [pulsemap_extractor] + + for i3_extractor in self._i3_extractors: + i3_extractor.set_files(i3_file="", gcd_file=self._gcd_file) + + @abstractmethod + def __call__(self, frame: I3Frame) -> bool: + """Define here how the module acts on the frame. + + Must return True if successful. + + Return True # SUPER IMPORTANT + """ + + def _make_graph( + self, frame: I3Frame + ) -> Data: # py-l-i-n-t-:- -d-i-s-able=invalid-name + """Process Physics I3Frame into graph.""" + # Extract features + node_features = self._extract_feature_array_from_frame(frame) + # Prepare graph data + if len(node_features) > 0: + data = self._graph_definition( + node_features=node_features, + node_feature_names=self._features, + ) + return Batch.from_data_list([data]) + else: + return None + + def _extract_feature_array_from_frame(self, frame: I3Frame) -> np.array: + """Apply the I3FeatureExtractors to the I3Frame. + + Arguments: + frame: Physics I3Frame (PFrame) + + Returns: + array with pulsemap + """ + features = None + for i3extractor in self._i3_extractors: + feature_dict = i3extractor(frame) + features_pulsemap = np.array( + [feature_dict[key] for key in self._features] + ).T + if features is None: + features = features_pulsemap + else: + features = np.concatenate( + (features, features_pulsemap), axis=0 + ) + return features + + def _add_to_frame(self, frame: I3Frame, data: Dict[str, Any]) -> I3Frame: + """Add every field in data to I3Frame. + + Arguments: + frame: I3Frame (physics) + data: Dictionary containing content that will be written to frame. + + Returns: + frame: Same I3Frame as input, but with the new entries + """ + assert isinstance( + data, dict + ), f"data must be of type dict. Got {type(data)}" + for key in data.keys(): + if key not in frame: + frame.Put(key, data[key]) + return frame
+ + + +
+[docs] +class I3InferenceModule(GraphNeTI3Module): + """General class for inference on i3 frames.""" + + def __init__( + self, + pulsemap: str, + features: List[str], + pulsemap_extractor: Union[ + List[I3FeatureExtractor], I3FeatureExtractor + ], + model_config: Union[ModelConfig, str], + state_dict: str, + model_name: str, + gcd_file: str, + prediction_columns: Optional[Union[List[str], str]] = None, + ): + """General class for inference on I3Frames (physics). + + Arguments: + pulsemap: the pulsmap that the model is expecting as input. + features: the features of the pulsemap that the model is expecting. + pulsemap_extractor: The extractor used to extract the pulsemap. + model_config: The ModelConfig (or path to it) that summarizes the + model used for inference. + state_dict: Path to state_dict containing the learned weights. + model_name: The name used for the model. Will help define the + named entry in the I3Frame. E.g. "dynedge". + gcd_file: path to associated gcd file. + prediction_columns: column names for the predictions of the model. + Will help define the named entry in the I3Frame. + E.g. ['energy_reco']. Optional. + """ + # Construct model & load weights + self.model = Model.from_config(model_config, trust=True) + self.model.load_state_dict(state_dict) + + super().__init__( + pulsemap=pulsemap, + features=features, + pulsemap_extractor=pulsemap_extractor, + gcd_file=gcd_file, + graph_definition=self.model._graph_definition, + ) + self.model.inference() + + self.model.to("cpu") + if prediction_columns is not None: + if isinstance(prediction_columns, str): + self.prediction_columns = [prediction_columns] + else: + self.prediction_columns = prediction_columns + else: + self.prediction_columns = self.model.prediction_labels + + self.model_name = model_name + + def __call__(self, frame: I3Frame) -> bool: + """Write predictions from model to frame.""" + # inference + graph = self._make_graph(frame) + if graph is not None: + predictions = self._inference(graph) + else: + predictions = np.repeat( + [np.nan], len(self.prediction_columns) + ).reshape(-1, len(self.prediction_columns)) + # Check dimensions of predictions and prediction columns + if len(predictions.shape) > 1: + dim = predictions.shape[1] + else: + dim = len(predictions) + assert dim == len( + self.prediction_columns + ), f"""predictions have shape {dim} but \n + prediction columns have [{self.prediction_columns}]""" + + # Build Dictionary of predictions + data = {} + assert predictions.shape[0] == 1 + for i in range(dim if isinstance(dim, int) else len(dim)): + try: + assert len(predictions[:, i]) == 1 + data[ + self.model_name + "_" + self.prediction_columns[i] + ] = I3Double(float(predictions[:, i][0])) + except IndexError: + data[ + self.model_name + "_" + self.prediction_columns[i] + ] = I3Double(predictions[0]) + + # Submission methods + frame = self._add_to_frame(frame=frame, data=data) + return True + + def _inference(self, data: Data) -> np.ndarray: + # Perform inference + task_predictions = self.model(data) + assert ( + len(task_predictions) == 1 + ), f"""This method assumes a single task. \n + Got {len(task_predictions)} tasks.""" + return self.model(data)[0].detach().numpy()
+ + + +
+[docs] +class I3PulseCleanerModule(I3InferenceModule): + """A specialized module for pulse cleaning. + + It is assumed that the model provided has been trained for this. + """ + + def __init__( + self, + pulsemap: str, + features: List[str], + pulsemap_extractor: Union[ + List[I3FeatureExtractor], I3FeatureExtractor + ], + model_config: str, + state_dict: str, + model_name: str, + *, + gcd_file: str, + threshold: float = 0.7, + discard_empty_events: bool = False, + prediction_columns: Optional[Union[List[str], str]] = None, + ): + """General class for inference on I3Frames (physics). + + Arguments: + pulsemap: the pulsmap that the model is expecting as input + (the one that is being cleaned). + features: the features of the pulsemap that the model is expecting. + pulsemap_extractor: The extractor used to extract the pulsemap. + model_config: The ModelConfig (or path to it) that summarizes the + model used for inference. + state_dict: Path to state_dict containing the learned weights. + model_name: The name used for the model. Will help define the named + entry in the I3Frame. E.g. "dynedge". + gcd_file: path to associated gcd file. + threshold: the threshold for being considered a positive case. + E.g., predictions >= threshold will be considered + to be signal, all else noise. + discard_empty_events: When true, this flag will eliminate events + whose cleaned pulse series are empty. Can be used + to speed up processing especially for noise + simulation, since it will not do any writing or + further calculations. + prediction_columns: column names for the predictions of the model. + Will help define the named entry in the I3Frame. + E.g. ['energy_reco']. Optional. + """ + super().__init__( + pulsemap=pulsemap, + features=features, + pulsemap_extractor=pulsemap_extractor, + model_config=model_config, + state_dict=state_dict, + model_name=model_name, + prediction_columns=prediction_columns, + gcd_file=gcd_file, + ) + self._threshold = threshold + self._predictions_key = f"{pulsemap}_{model_name}_Predictions" + self._total_pulsemap_name = f"{pulsemap}_{model_name}_Pulses" + self._discard_empty_events = discard_empty_events + + def __call__(self, frame: I3Frame) -> bool: + """Add a cleaned pulsemap to frame.""" + # inference + gcd_file = self._gcd_file + graph = self._make_graph(frame) + if graph is None: # If there is no pulses to clean + return False + predictions = self._inference(graph) + if self._discard_empty_events: + if sum(predictions > self._threshold) == 0: + return False + + if len(predictions.shape) == 1: + predictions = predictions.reshape(-1, 1) + + assert predictions.shape[1] == 1 + + # Build Dictionary of predictions + data = {} + + predictions_map = self._construct_prediction_map( + frame=frame, predictions=predictions + ) + + # Adds the raw predictions to dictionary + if self._predictions_key not in frame.keys(): + data[self._predictions_key] = predictions_map + + # Create a pulse map mask, indicating the pulses that are over + # threshold (e.g. identified as signal) and therefore should be kept + # Using a lambda function to evaluate which pulses to keep by + # checking the prediction for each pulse + # (Adds the actual pulsemap to dictionary) + if self._total_pulsemap_name not in frame.keys(): + data[ + self._total_pulsemap_name + ] = dataclasses.I3RecoPulseSeriesMapMask( + frame, + self._pulsemap, + lambda om_key, index, pulse: predictions_map[om_key][index] + >= self._threshold, + ) + + # Submit predictions and general pulsemap + frame = self._add_to_frame(frame=frame, data=data) + data = {} + # Adds an additional pulsemap for each DOM type + if isinstance( + self._i3_extractors[0], I3FeatureExtractorIceCubeUpgrade + ): + mDOMMap, DEggMap, IceCubeMap = self._split_pulsemap_in_dom_types( + frame=frame, gcd_file=gcd_file + ) + + if f"{self._total_pulsemap_name}_mDOMs_Only" not in frame.keys(): + data[ + f"{self._total_pulsemap_name}_mDOMs_Only" + ] = dataclasses.I3RecoPulseSeriesMap(mDOMMap) + + if f"{self._total_pulsemap_name}_dEggs_Only" not in frame.keys(): + data[ + f"{self._total_pulsemap_name}_dEggs_Only" + ] = dataclasses.I3RecoPulseSeriesMap(DEggMap) + + if f"{self._total_pulsemap_name}_pDOMs_Only" not in frame.keys(): + data[ + f"{self._total_pulsemap_name}_pDOMs_Only" + ] = dataclasses.I3RecoPulseSeriesMap(IceCubeMap) + + # Submits the additional pulsemaps to the frame + frame = self._add_to_frame(frame=frame, data=data) + + return True + + def _split_pulsemap_in_dom_types( + self, frame: I3Frame, gcd_file: Any + ) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: + """Will split the cleaned pulsemap into multiple pulsemaps. + + Arguments: + frame: I3Frame (physics) + gcd_file: path to associated gcd file + + Returns: + mDOMMap, DeGGMap, IceCubeMap + """ + g = dataio.I3File(gcd_file) + gFrame = g.pop_frame() + while "I3Geometry" not in gFrame.keys(): + gFrame = g.pop_frame() + omGeoMap = gFrame["I3Geometry"].omgeo + + mDOMMap, DEggMap, IceCubeMap = {}, {}, {} + pulses = dataclasses.I3RecoPulseSeriesMap.from_frame( + frame, self._total_pulsemap_name + ) + for P in pulses: + om = omGeoMap[P[0]] + if om.omtype == 130: # "mDOM" + mDOMMap[P[0]] = P[1] + elif om.omtype == 120: # "DEgg" + DEggMap[P[0]] = P[1] + elif om.omtype == 20: # "IceCube / pDOM" + IceCubeMap[P[0]] = P[1] + return mDOMMap, DEggMap, IceCubeMap + + def _construct_prediction_map( + self, frame: I3Frame, predictions: np.ndarray + ) -> I3MapKeyVectorDouble: + """Make a pulsemap from predictions (for all OM types). + + Arguments: + frame: I3Frame (physics) + predictions: predictions from GNN + + Returns: + predictions_map: a pulsemap from predictions + """ + pulsemap = dataclasses.I3RecoPulseSeriesMap.from_frame( + frame, self._pulsemap + ) + + idx = 0 + predictions = predictions.squeeze(1) + predictions_map = dataclasses.I3MapKeyVectorDouble() + for om_key, pulses in pulsemap.items(): + num_pulses = len(pulses) + predictions_map[om_key] = predictions[ + idx : idx + num_pulses + ].tolist() + idx += num_pulses + + # Checks + assert idx == len( + predictions + ), """Not all predictions were mapped to pulses,\n + validation of predictions have failed.""" + + assert ( + pulsemap.keys() == predictions_map.keys() + ), """Input pulse map and predictions map do \n + not contain exactly the same OMs""" + return predictions_map
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/coarsening.html b/_modules/graphnet/models/coarsening.html new file mode 100644 index 000000000..6a260f076 --- /dev/null +++ b/_modules/graphnet/models/coarsening.html @@ -0,0 +1,711 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.coarsening — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.coarsening

+"""Class(es) for coarsening operations (i.e., clustering, or local pooling)."""
+
+from abc import abstractmethod
+from typing import List, Optional, Union
+from copy import deepcopy
+import torch
+from torch import LongTensor, Tensor
+from torch_geometric.data import Data, Batch
+from sklearn.cluster import DBSCAN
+
+# from torch_geometric.utils import unbatch_edge_index
+from graphnet.models.components.pool import (
+    group_by,
+    avg_pool,
+    max_pool,
+    min_pool,
+    sum_pool,
+    avg_pool_x,
+    max_pool_x,
+    min_pool_x,
+    sum_pool_x,
+    std_pool_x,
+)
+from graphnet.models import Model
+from graphnet.utilities.config import save_model_config
+
+# Utility method(s)
+from torch_geometric.utils import degree
+
+# NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903]
+# TODO:  Remove once bumping to torch_geometric>=2.1.0
+#       See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md]
+
+
+
+[docs] +def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]: + # noqa: D401 + r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector. + + Args: + edge_index (Tensor): The edge_index tensor. Must be ordered. + batch (LongTensor): The batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. Must be ordered. + :rtype: :class:`List[Tensor]` + """ + deg = degree(batch, dtype=torch.int64) + ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0) + + edge_batch = batch[edge_index[0]] + edge_index = edge_index - ptr[edge_batch] + sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist() + return edge_index.split(sizes, dim=1)
+ + + +
+[docs] +class Coarsening(Model): + """Base class for coarsening operations.""" + + # Class variables + reduce_options = { + "avg": (avg_pool, avg_pool_x), + "min": (min_pool, min_pool_x), + "max": (max_pool, max_pool_x), + "sum": (sum_pool, sum_pool_x), + } + + @save_model_config + def __init__( + self, + reduce: str = "avg", + transfer_attributes: bool = True, + ): + """Construct `Coarsening`.""" + assert reduce in self.reduce_options + + ( + self._reduce_method, + self._attribute_reduce_method, + ) = self.reduce_options[reduce] + self._do_transfer_attributes = transfer_attributes + + # Base class constructor + super().__init__() + + @abstractmethod + def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor: + """Cluster nodes in `data` by assigning a cluster index to each.""" + + def _additional_features(self, cluster: LongTensor, data: Batch) -> Tensor: + """Perform additional poolings of feature tensor `x` on `data`. + + By default the nominal `pooling_method` is used for features as well. + This method can be overwritten for bespoke coarsening operations. + """ + + def _transfer_attributes( + self, cluster: LongTensor, original_data: Batch, pooled_data: Batch + ) -> Batch: + """Transfer attributes on `original_data` to `pooled_data`.""" + # Check(s) + if not self._do_transfer_attributes: + return pooled_data + + attributes = list(original_data._store.keys()) + batch: Optional[LongTensor] = original_data.batch + for ix, attr in enumerate(attributes): + if attr not in pooled_data._store: + values: Tensor = getattr(original_data, attr) + + attr_is_node_level_tensor = False + if isinstance(values, Tensor): + if batch is None: + attr_is_node_level_tensor = ( + values.dim() > 1 or values.size(dim=0) > 1 + ) + else: + attr_is_node_level_tensor = ( + values.size() == original_data.batch.size() + ) + + if attr_is_node_level_tensor: + values = self._attribute_reduce_method( + cluster, + values, + batch=torch.zeros_like(values, dtype=torch.int32), + )[0] + + setattr(pooled_data, attr, values) + + return pooled_data + +
+[docs] + def forward(self, data: Union[Data, Batch]) -> Union[Data, Batch]: + """Perform coarsening operation.""" + # Get tensor of cluster indices for each node. + cluster: LongTensor = self._perform_clustering(data) + + # Check whether a graph has already been built. Otherwise, set a dummy + # connectivity, as this is required by pooling functions. + edge_index = data.edge_index + if edge_index is None: + data.edge_index = torch.tensor([[]], dtype=torch.int64) + + # Pool `data` object, including `x`, `batch`. and `edge_index`. + pooled_data: Batch = self._reduce_method(cluster, data) + + # Optionally overwrite feature tensor + x = self._additional_features(cluster, data) + if x is not None: + pooled_data.x = torch.cat( + ( + pooled_data.x, + x, + ), + dim=1, + ) + + # Reset `edge_index` if necessary. + if edge_index is None: + data.edge_index = edge_index + pooled_data.edge_index = edge_index + + # Transfer attributes on `data`, pooling as required. + pooled_data = self._transfer_attributes(cluster, data, pooled_data) + + # Reconstruct Batch Attributes + if isinstance(data, Batch): # if a Batch object + pooled_data = self._reconstruct_batch(data, pooled_data) + return pooled_data
+ + + def _reconstruct_batch(self, original: Data, pooled: Data) -> Data: + pooled = self._add_slice_dict(original, pooled) + pooled = self._add_inc_dict(original, pooled) + return pooled + + def _add_slice_dict(self, original: Data, pooled: Data) -> Data: + # Copy original slice_dict and count nodes in each graph in pooled batch + slice_dict = deepcopy(original._slice_dict) + _, counts = torch.unique_consecutive(pooled.batch, return_counts=True) + # Reconstruct the entry in slice_dict for pulsemaps - only these are affected by pooling + pulsemap_slice = [0] + for i in range(len(counts)): + pulsemap_slice.append(pulsemap_slice[i] + counts[i].item()) + + # Identifies pulsemap entries in slice_dict and set them to pulsemap_slice + for field in slice_dict.keys(): + if (original._num_graphs) == slice_dict[field][-1]: + pass # not pulsemap, so skip + else: + slice_dict[field] = pulsemap_slice + pooled._slice_dict = slice_dict + return pooled + + def _add_inc_dict(self, original: Data, pooled: Data) -> Data: + # not changed by coarsening + pooled._inc_dict = deepcopy(original._inc_dict) + return pooled
+ + + +
+[docs] +class AttributeCoarsening(Coarsening): + """Coarsen pulses based on specified attributes.""" + + @save_model_config + def __init__( + self, + attributes: List[str], + reduce: str = "avg", + transfer_attributes: bool = True, + ): + """Construct `SimpleCoarsening`.""" + self._attributes = attributes + + # Base class constructor + super().__init__(reduce, transfer_attributes) + + def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor: + """Cluster nodes in `data` by assigning a cluster index to each.""" + dom_index = group_by(data, self._attributes) + return dom_index
+ + + +
+[docs] +class DOMCoarsening(Coarsening): + """Coarsen pulses to DOM-level.""" + + def __init__( + self, + reduce: str = "avg", + transfer_attributes: bool = True, + keys: Optional[List[str]] = None, + ): + """Cluster pulses on the same DOM.""" + super().__init__(reduce, transfer_attributes) + if keys is None: + self._keys = [ + "dom_x", + "dom_y", + "dom_z", + "rde", + "pmt_area", + ] + else: + self._keys = keys + + def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor: + """Cluster nodes in `data` by assigning a cluster index to each.""" + dom_index = group_by(data, self._keys) + return dom_index
+ + + +
+[docs] +class CustomDOMCoarsening(DOMCoarsening): + """Coarsen pulses to DOM-level with additional attributes.""" + + def _additional_features(self, cluster: LongTensor, data: Data) -> Tensor: + """Perform Additional poolings of feature tensor `x` on `data`.""" + batch = data.batch + + features = data.features + if batch is not None: + features = [feats[0] for feats in features] + + ix_time = features.index("dom_time") + ix_charge = features.index("charge") + + time = data.x[:, ix_time] + charge = data.x[:, ix_charge] + + x = torch.stack( + ( + min_pool_x(cluster, time, batch)[0], + max_pool_x(cluster, time, batch)[0], + std_pool_x(cluster, time, batch)[0], + min_pool_x(cluster, charge, batch)[0], + max_pool_x(cluster, charge, batch)[0], + std_pool_x(cluster, charge, batch)[0], + sum_pool_x(cluster, torch.ones_like(charge), batch)[ + 0 + ], # Num. nodes (pulses) per cluster (DOM) + ), + dim=1, + ) + + return x
+ + + +
+[docs] +class DOMAndTimeWindowCoarsening(Coarsening): + """Coarsen pulses to DOM-level, with additional time-window clustering.""" + + def __init__( + self, + time_window: float, + reduce: str = "avg", + transfer_attributes: bool = True, + keys: List[str] = [ + "dom_x", + "dom_y", + "dom_z", + "rde", + "pmt_area", + ], + time_key: str = "dom_time", + ): + """Cluster pulses on the same DOM within `time_window`.""" + super().__init__(reduce, transfer_attributes) + self._time_window = time_window + self._cluster_method = DBSCAN(self._time_window, min_samples=1) + self._keys = keys + self._time_key = time_key + + def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor: + """Cluster nodes in `data` by assigning a cluster index to each.""" + dom_index = group_by(data, self._keys) + if data.batch is not None: + features = data.features[0] + else: + features = data.features + + ix_time = features.index(self._time_key) + hit_times = data.x[:, ix_time] + + # Scale up dom_index to make sure clusters are well separated + times_and_domids = torch.stack( + [ + hit_times, + dom_index * self._time_window * 10, + ] + ).T + clusters = torch.tensor( + self._cluster_method.fit_predict(times_and_domids.cpu()), + device=hit_times.device, + ) + + return clusters
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/components/layers.html b/_modules/graphnet/models/components/layers.html new file mode 100644 index 000000000..555ac857b --- /dev/null +++ b/_modules/graphnet/models/components/layers.html @@ -0,0 +1,579 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.components.layers — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.components.layers

+"""Class(es) implementing layers to be used in `graphnet` models."""
+
+from typing import Any, Callable, Optional, Sequence, Union, List, Tuple
+
+import torch
+from torch.functional import Tensor
+from torch_geometric.nn import EdgeConv
+from torch_geometric.nn.pool import knn_graph
+from torch_geometric.typing import Adj, PairTensor
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.nn.inits import reset
+from torch.nn.modules import TransformerEncoder, TransformerEncoderLayer
+from torch.nn.modules.normalization import LayerNorm
+from torch_geometric.utils import to_dense_batch
+from pytorch_lightning import LightningModule
+
+
+
+[docs] +class DynEdgeConv(EdgeConv, LightningModule): + """Dynamical edge convolution layer.""" + + def __init__( + self, + nn: Callable, + aggr: str = "max", + nb_neighbors: int = 8, + features_subset: Optional[Union[Sequence[int], slice]] = None, + **kwargs: Any, + ): + """Construct `DynEdgeConv`. + + Args: + nn: The MLP/torch.Module to be used within the `EdgeConv`. + aggr: Aggregation method to be used with `EdgeConv`. + nb_neighbors: Number of neighbours to be clustered after the + `EdgeConv` operation. + features_subset: Subset of features in `Data.x` that should be used + when dynamically performing the new graph clustering after the + `EdgeConv` operation. Defaults to all features. + **kwargs: Additional features to be passed to `EdgeConv`. + """ + # Check(s) + if features_subset is None: + features_subset = slice(None) # Use all features + assert isinstance(features_subset, (list, slice)) + + # Base class constructor + super().__init__(nn=nn, aggr=aggr, **kwargs) + + # Additional member variables + self.nb_neighbors = nb_neighbors + self.features_subset = features_subset + +
+[docs] + def forward( + self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None + ) -> Tensor: + """Forward pass.""" + # Standard EdgeConv forward pass + x = super().forward(x, edge_index) + + # Recompute adjacency + edge_index = knn_graph( + x=x[:, self.features_subset], + k=self.nb_neighbors, + batch=batch, + ).to(self.device) + + return x, edge_index
+
+ + + +
+[docs] +class EdgeConvTito(MessagePassing, LightningModule): + """Implementation of EdgeConvTito layer used in TITO solution for. + + 'IceCube - Neutrinos in Deep' kaggle competition. + """ + + def __init__( + self, + nn: Callable, + aggr: str = "max", + **kwargs: Any, + ): + """Construct `EdgeConvTito`. + + Args: + nn: The MLP/torch.Module to be used within the `EdgeConvTito`. + aggr: Aggregation method to be used with `EdgeConvTito`. + **kwargs: Additional features to be passed to `EdgeConvTito`. + """ + super().__init__(aggr=aggr, **kwargs) + self.nn = nn + self.reset_parameters() + +
+[docs] + def reset_parameters(self) -> None: + """Reset all learnable parameters of the module.""" + reset(self.nn)
+ + +
+[docs] + def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor: + """Forward pass.""" + if isinstance(x, Tensor): + x = (x, x) + # propagate_type: (x: PairTensor) + return self.propagate(edge_index, x=x, size=None)
+ + +
+[docs] + def message(self, x_i: Tensor, x_j: Tensor) -> Tensor: + """Edgeconvtito message passing.""" + return self.nn( + torch.cat([x_i, x_j - x_i, x_j], dim=-1) + ) # EdgeConvTito
+ + + def __repr__(self) -> str: + """Print out module name.""" + return f"{self.__class__.__name__}(nn={self.nn})"
+ + + +
+[docs] +class DynTrans(EdgeConvTito, LightningModule): + """Implementation of dynTrans1 layer used in TITO solution for. + + 'IceCube - Neutrinos in Deep' kaggle competition. + """ + + def __init__( + self, + layer_sizes: Optional[List[int]] = None, + aggr: str = "max", + features_subset: Optional[Union[Sequence[int], slice]] = None, + n_head: int = 8, + **kwargs: Any, + ): + """Construct `DynTrans`. + + Args: + nn: The MLP/torch.Module to be used within the `DynTrans`. + layer_sizes: List of layer sizes to be used in `DynTrans`. + aggr: Aggregation method to be used with `DynTrans`. + features_subset: Subset of features in `Data.x` that should be used + when dynamically performing the new graph clustering after the + `EdgeConv` operation. Defaults to all features. + n_head: Number of heads to be used in the multiheadattention models. + **kwargs: Additional features to be passed to `DynTrans`. + """ + # Check(s) + if features_subset is None: + features_subset = slice(None) # Use all features + assert isinstance(features_subset, (list, slice)) + + if layer_sizes is None: + layer_sizes = [256, 256, 256] + layers = [] + for ix, (nb_in, nb_out) in enumerate( + zip(layer_sizes[:-1], layer_sizes[1:]) + ): + if ix == 0: + nb_in *= 3 # edgeConv1 + layers.append(torch.nn.Linear(nb_in, nb_out)) + layers.append(torch.nn.LeakyReLU()) + d_model = nb_out + + # Base class constructor + super().__init__(nn=torch.nn.Sequential(*layers), aggr=aggr, **kwargs) + + # Additional member variables + self.features_subset = features_subset + + self.norm1 = LayerNorm(d_model, eps=1e-5) # lNorm + + # Transformer layer(s) + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=n_head, + batch_first=True, + norm_first=False, + ) + self._transformer_encoder = TransformerEncoder( + encoder_layer, num_layers=1 + ) + +
+[docs] + def forward( + self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None + ) -> Tensor: + """Forward pass.""" + x_out = super().forward(x, edge_index) + + if x_out.shape[-1] == x.shape[-1]: + x = x + x_out + else: + x = x_out + + x = self.norm1(x) # lNorm + + # Transformer layer + x, mask = to_dense_batch(x, batch) + x = self._transformer_encoder(x, src_key_padding_mask=~mask) + x = x[mask] + + return x
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/components/pool.html b/_modules/graphnet/models/components/pool.html new file mode 100644 index 000000000..2523dfb69 --- /dev/null +++ b/_modules/graphnet/models/components/pool.html @@ -0,0 +1,656 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.components.pool — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.components.pool

+"""Functions for performing pooling/clustering/coarsening."""
+
+from typing import Any, Callable, List, Optional, Union
+
+import torch
+from torch import LongTensor, Tensor
+from torch_geometric.data import Data, Batch
+from torch_geometric.nn.pool.consecutive import consecutive_cluster
+from torch_geometric.nn.pool.pool import pool_edge, pool_batch, pool_pos
+from torch_scatter import scatter, scatter_std
+
+from torch_geometric.nn.pool import (
+    avg_pool,
+    max_pool,
+    avg_pool_x,
+    max_pool_x,
+)
+
+
+
+[docs] +def min_pool( + cluster: LongTensor, data: Data, transform: Optional[Any] = None +) -> Data: + """Perform min-pooling of `Data`. + + Like `max_pool, just negating `data.x`. + """ + data.x = -data.x + data_pooled = max_pool( + cluster, + data, + transform, + ) + data.x = -data.x + data_pooled.x = -data_pooled.x + return data_pooled
+ + + +
+[docs] +def min_pool_x( + cluster: LongTensor, + x: Tensor, + batch: LongTensor, + size: Optional[int] = None, +) -> Tensor: + """Perform min-pooling of `Tensor`. + + Like `max_pool_x, just negating `x`. + """ + ret = max_pool_x(cluster, -x, batch, size) + if size is None: + return (-ret[0], ret[1]) + else: + return -ret
+ + + +
+[docs] +def sum_pool_and_distribute( + tensor: Tensor, + cluster_index: LongTensor, + batch: Optional[LongTensor] = None, +) -> Tensor: + """Sum-pool values and distribute result to the individual nodes.""" + if batch is None: + batch = torch.zeros(tensor.size(dim=0)).long() + tensor_pooled, _ = sum_pool_x(cluster_index, tensor, batch) + inv, _ = consecutive_cluster(cluster_index) + tensor_unpooled = tensor_pooled[inv] + return tensor_unpooled
+ + + +def _group_identical( + tensor: Tensor, batch: Optional[LongTensor] = None +) -> LongTensor: + """Group rows in `tensor` that are identical. + + Args: + tensor: Tensor of shape [N, F]. + batch: Batch indices, to only group identical rows within batches. + + Returns: + List of group indices, from 0 to num. groups - 1, assigning all + identical rows to the same group. + """ + if batch is not None: + tensor = torch.cat((batch.unsqueeze(dim=1), tensor), dim=1) + return torch.unique(tensor, return_inverse=True, sorted=False, dim=0)[1] + + +
+[docs] +def group_by(data: Union[Data, Batch], keys: List[str]) -> LongTensor: + """Group nodes in `data` that have identical values of `keys`. + + This grouping is done with in each event in case of batching. This allows + for, e.g., assigning the same index to all pulses on the same PMT or DOM in + the same event. This can be used for coarsening graphs, e.g., from pulse- + level to DOM-level by aggregating feature across each group returned by this + method. + + Example: + Given: + data.f1 = [1,1,2,2,2] + data.f2 = [6,7,7,7,8] + Calls: + groupby(data, ['f1']) -> [0, 0, 1, 1, 1] + groupby(data, ['f2']) -> [0, 1, 1, 1, 2] + groupby(data, ['f1', 'f2']) -> [0, 1, 2, 2, 3] + """ + features = [getattr(data, key) for key in keys] + tensor = torch.stack(features).T # .int() @TODO: Required? Use rounding? + batch = getattr(data, "batch", None) + index = _group_identical(tensor, batch) + return index
+ + + +
+[docs] +def group_pulses_to_dom(data: Data) -> Data: + """Group pulses on the same DOM, using DOM and string number.""" + data.dom_index = group_by(data, ["dom_number", "string"]) + return data
+ + + +
+[docs] +def group_pulses_to_pmt(data: Data) -> Data: + """Group pulses on the same PMT, using PMT, DOM, and string number.""" + data.pmt_index = group_by(data, ["pmt_number", "dom_number", "string"]) + return data
+ + + +# Below mirroring `torch_geometric.nn.pool.{avg,max}_pool.py`. +def _sum_pool_x( + cluster: LongTensor, x: Tensor, size: Optional[int] = None +) -> Tensor: + return scatter(x, cluster, dim=0, dim_size=size, reduce="sum") + + +def _std_pool_x( + cluster: LongTensor, x: Tensor, size: Optional[int] = None +) -> Tensor: + return scatter_std(x, cluster, dim=0, dim_size=size, unbiased=False) + + +
+[docs] +def sum_pool_x( + cluster: LongTensor, + x: Tensor, + batch: LongTensor, + size: Optional[int] = None, +) -> Tensor: + r"""Sum-pool node features according to the clustering defined in `cluster`. + + Args: + cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, + \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. + x: Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. + batch: Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, + B-1\}}^N`, which assigns each node to a specific example. + size: The maximum number of clusters in a single + example. This property is useful to obtain a batch-wise dense + representation, *e.g.* for applying FC layers, but should only be + used if the size of the maximum number of clusters per example is + known in advance. + """ + if size is not None: + batch_size = int(batch.max().item()) + 1 + return _sum_pool_x(cluster, x, batch_size * size), None + + cluster, perm = consecutive_cluster(cluster) + x = _sum_pool_x(cluster, x) + batch = pool_batch(perm, batch) + + return x, batch
+ + + +
+[docs] +def std_pool_x( + cluster: LongTensor, + x: Tensor, + batch: LongTensor, + size: Optional[int] = None, +) -> Tensor: + r"""Std-pool node features according to the clustering defined in `cluster`. + + Args: + cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, + \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. + x: Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. + batch: Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, + B-1\}}^N`, which assigns each node to a specific example. + size: The maximum number of clusters in a single + example. This property is useful to obtain a batch-wise dense + representation, *e.g.* for applying FC layers, but should only be + used if the size of the maximum number of clusters per example is + known in advance. + """ + if size is not None: + batch_size = int(batch.max().item()) + 1 + return _std_pool_x(cluster, x, batch_size * size), None + + cluster, perm = consecutive_cluster(cluster) + x = _std_pool_x(cluster, x) + batch = pool_batch(perm, batch) + + return x, batch
+ + + +
+[docs] +def sum_pool( + cluster: LongTensor, data: Data, transform: Optional[Callable] = None +) -> Data: + r"""Pool and coarsen graph according to the clustering defined in `cluster`. + + All nodes within the same cluster will be represented as one node. + Final node features are defined by the *sum* of features of all nodes + within the same cluster, node positions are averaged and edge indices are + defined to be the union of the edge indices of all nodes within the same + cluster. + + Args: + cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, + \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. + data: Graph data object. + transform: A function/transform that takes in the + coarsened and pooled :obj:`torch_geometric.data.Data` object and + returns a transformed version. + """ + cluster, perm = consecutive_cluster(cluster) + + x = None if data.x is None else _sum_pool_x(cluster, data.x) + index, attr = pool_edge(cluster, data.edge_index, data.edge_attr) + batch = None if data.batch is None else pool_batch(perm, data.batch) + pos = None if data.pos is None else pool_pos(cluster, data.pos) + + data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos) + + if transform is not None: + data = transform(data) + + return data
+ + + +
+[docs] +def std_pool( + cluster: LongTensor, data: Data, transform: Optional[Callable] = None +) -> Data: + r"""Pool and coarsen graph according to the clustering defined in `cluster`. + + All nodes within the same cluster will be represented as one node. + Final node features are defined by the *std* of features of all nodes + within the same cluster, node positions are averaged and edge indices are + defined to be the union of the edge indices of all nodes within the same + cluster. + + Args: + cluster: Cluster vector :math:`\mathbf{c} \in \{ 0, + \ldots, N - 1 \}^N`, which assigns each node to a specific cluster. + data: Graph data object. + transform: A function/transform that takes in the + coarsened and pooled :obj:`torch_geometric.data.Data` object and + returns a transformed version. + """ + cluster, perm = consecutive_cluster(cluster) + + x = None if data.x is None else _std_pool_x(cluster, data.x) + index, attr = pool_edge(cluster, data.edge_index, data.edge_attr) + batch = None if data.batch is None else pool_batch(perm, data.batch) + pos = None if data.pos is None else pool_pos(cluster, data.pos) + + data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos) + + if transform is not None: + data = transform(data) + + return data
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/detector/detector.html b/_modules/graphnet/models/detector/detector.html new file mode 100644 index 000000000..eef62aefc --- /dev/null +++ b/_modules/graphnet/models/detector/detector.html @@ -0,0 +1,421 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.detector.detector — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.detector.detector

+"""Base detector-specific `Model` class(es)."""
+
+from abc import abstractmethod
+from typing import Dict, Callable, List
+
+from torch_geometric.data import Data
+import torch
+
+from graphnet.models import Model
+from graphnet.utilities.decorators import final
+from graphnet.utilities.config import save_model_config
+
+
+
+[docs] +class Detector(Model): + """Base class for all detector-specific read-ins in graphnet.""" + + @save_model_config + def __init__(self) -> None: + """Construct `Detector`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + +
+[docs] + @abstractmethod + def feature_map(self) -> Dict[str, Callable]: + """List of features used/assumed by inheriting `Detector` objects."""
+ + +
+[docs] + @final + def forward( # type: ignore + self, node_features: torch.tensor, node_feature_names: List[str] + ) -> Data: + """Pre-process graph `Data` features and build graph adjacency.""" + return self._standardize(node_features, node_feature_names)
+ + + @final + def _standardize( + self, node_features: torch.tensor, node_feature_names: List[str] + ) -> Data: + for idx, feature in enumerate(node_feature_names): + try: + node_features[:, idx] = self.feature_map()[feature]( # type: ignore + node_features[:, idx] + ) + except KeyError as e: + self.warning( + f"""No Standardization function found for '{feature}'""" + ) + raise e + return node_features + + def _identity(self, x: torch.tensor) -> torch.tensor: + """Apply no standardization to input.""" + return x
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/detector/icecube.html b/_modules/graphnet/models/detector/icecube.html new file mode 100644 index 000000000..76d9a8e16 --- /dev/null +++ b/_modules/graphnet/models/detector/icecube.html @@ -0,0 +1,528 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.detector.icecube — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.detector.icecube

+"""IceCube-specific `Detector` class(es)."""
+
+from typing import Dict, Callable
+import torch
+
+from graphnet.models.detector.detector import Detector
+
+
+
+[docs] +class IceCube86(Detector): + """`Detector` class for IceCube-86.""" + +
+[docs] + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension of input data.""" + feature_map = { + "dom_x": self._dom_xyz, + "dom_y": self._dom_xyz, + "dom_z": self._dom_xyz, + "dom_time": self._dom_time, + "charge": self._charge, + "rde": self._rde, + "pmt_area": self._pmt_area, + } + return feature_map
+ + + def _dom_xyz(self, x: torch.tensor) -> torch.tensor: + return x / 500.0 + + def _dom_time(self, x: torch.tensor) -> torch.tensor: + return (x - 1.0e04) / 3.0e4 + + def _charge(self, x: torch.tensor) -> torch.tensor: + return torch.log10(x) + + def _rde(self, x: torch.tensor) -> torch.tensor: + return (x - 1.25) / 0.25 + + def _pmt_area(self, x: torch.tensor) -> torch.tensor: + return x / 0.05
+ + + +
+[docs] +class IceCubeKaggle(Detector): + """`Detector` class for Kaggle Competition.""" + +
+[docs] + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension of input data.""" + feature_map = { + "x": self._xyz, + "y": self._xyz, + "z": self._xyz, + "time": self._time, + "charge": self._charge, + "auxiliary": self._identity, + } + return feature_map
+ + + def _xyz(self, x: torch.tensor) -> torch.tensor: + return x / 500.0 + + def _time(self, x: torch.tensor) -> torch.tensor: + return (x - 1.0e04) / 3.0e4 + + def _charge(self, x: torch.tensor) -> torch.tensor: + return torch.log10(x) / 3.0
+ + + +
+[docs] +class IceCubeDeepCore(Detector): + """`Detector` class for IceCube-DeepCore.""" + +
+[docs] + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension of input data.""" + feature_map = { + "dom_x": self._dom_xy, + "dom_y": self._dom_xy, + "dom_z": self._dom_z, + "dom_time": self._dom_time, + "charge": self._identity, + "rde": self._rde, + "pmt_area": self._pmt_area, + } + return feature_map
+ + + def _dom_xy(self, x: torch.tensor) -> torch.tensor: + return x / 100.0 + + def _dom_z(self, x: torch.tensor) -> torch.tensor: + return (x + 350.0) / 100.0 + + def _dom_time(self, x: torch.tensor) -> torch.tensor: + return ((x / 1.05e04) - 1.0) * 20.0 + + def _rde(self, x: torch.tensor) -> torch.tensor: + return (x - 1.25) / 0.25 + + def _pmt_area(self, x: torch.tensor) -> torch.tensor: + return x / 0.05
+ + + +
+[docs] +class IceCubeUpgrade(Detector): + """`Detector` class for IceCube-Upgrade.""" + +
+[docs] + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension of input data.""" + feature_map = { + "dom_x": self._dom_xyz, + "dom_y": self._dom_xyz, + "dom_z": self._dom_xyz, + "dom_time": self._dom_time, + "charge": self._charge, + "rde": self._identity, + "pmt_area": self._pmt_area, + "string": self._string, + "pmt_number": self._pmt_number, + "dom_number": self._dom_number, + "pmt_dir_x": self._identity, + "pmt_dir_y": self._identity, + "pmt_dir_z": self._identity, + "dom_type": self._dom_type, + } + + return feature_map
+ + + def _dom_time(self, x: torch.tensor) -> torch.tensor: + return (x / 2e04) - 1.0 + + def _charge(self, x: torch.tensor) -> torch.tensor: + return torch.log10(x) / 2.0 + + def _string(self, x: torch.tensor) -> torch.tensor: + return (x - 50.0) / 50.0 + + def _pmt_number(self, x: torch.tensor) -> torch.tensor: + return x / 20.0 + + def _dom_number(self, x: torch.tensor) -> torch.tensor: + return (x - 60.0) / 60.0 + + def _dom_type(self, x: torch.tensor) -> torch.tensor: + return x / 130.0 + + def _dom_xyz(self, x: torch.tensor) -> torch.tensor: + return x / 500.0 + + def _pmt_area(self, x: torch.tensor) -> torch.tensor: + return x / 0.05
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/detector/prometheus.html b/_modules/graphnet/models/detector/prometheus.html new file mode 100644 index 000000000..7f234af38 --- /dev/null +++ b/_modules/graphnet/models/detector/prometheus.html @@ -0,0 +1,395 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.detector.prometheus — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.detector.prometheus

+"""Prometheus-specific `Detector` class(es)."""
+
+from typing import Dict, Callable
+import torch
+
+from graphnet.models.detector.detector import Detector
+
+
+
+[docs] +class Prometheus(Detector): + """`Detector` class for Prometheus prototype.""" + +
+[docs] + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension.""" + feature_map = { + "sensor_pos_x": self._sensor_pos_xy, + "sensor_pos_y": self._sensor_pos_xy, + "sensor_pos_z": self._sensor_pos_z, + "t": self._t, + } + return feature_map
+ + + def _sensor_pos_xy(self, x: torch.tensor) -> torch.tensor: + return x / 100 + + def _sensor_pos_z(self, x: torch.tensor) -> torch.tensor: + return (x + 350) / 100 + + def _t(self, x: torch.tensor) -> torch.tensor: + return ((x / 1.05e04) - 1.0) * 20.0
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/convnet.html b/_modules/graphnet/models/gnn/convnet.html new file mode 100644 index 000000000..ada586304 --- /dev/null +++ b/_modules/graphnet/models/gnn/convnet.html @@ -0,0 +1,486 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.gnn.convnet — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.gnn.convnet

+"""Implementation of the ConvNet GNN model architecture.
+
+Author: Martin Ha Minh
+"""
+
+import torch
+from torch import Tensor
+from torch.nn import BatchNorm1d, Linear, Dropout
+import torch.nn.functional as F
+from torch_geometric.nn import TAGConv, global_add_pool, global_max_pool
+from torch_geometric.data import Data
+
+from graphnet.utilities.config import save_model_config
+from graphnet.models.gnn.gnn import GNN
+
+
+
+[docs] +class ConvNet(GNN): + """ConvNet (convolutional network) model.""" + + @save_model_config + def __init__( + self, + nb_inputs: int, + nb_outputs: int, + nb_intermediate: int = 128, + dropout_ratio: float = 0.3, + ): + """Construct `ConvNet`. + + Args: + nb_inputs: Number of input features, i.e. dimension of input + layer. + nb_outputs: Number of prediction labels, i.e. dimension of + output layer. + nb_intermediate: Number of nodes in intermediate layer(s). + dropout_ratio: Fraction of nodes to drop. + """ + # Base class constructor + super().__init__(nb_inputs, nb_outputs) + + # Member variables + self.nb_intermediate = nb_intermediate + self.nb_intermediate2 = 6 * self.nb_intermediate + + # Architecture configuration + self.conv1 = TAGConv(self.nb_inputs, self.nb_intermediate, 2) + self.conv2 = TAGConv(self.nb_intermediate, self.nb_intermediate, 2) + self.conv3 = TAGConv(self.nb_intermediate, self.nb_intermediate, 2) + + self.batchnorm1 = BatchNorm1d(self.nb_intermediate2) + + self.linear1 = Linear(self.nb_intermediate2, self.nb_intermediate2) + self.linear2 = Linear(self.nb_intermediate2, self.nb_intermediate2) + self.linear3 = Linear(self.nb_intermediate2, self.nb_intermediate2) + self.linear4 = Linear(self.nb_intermediate2, self.nb_intermediate2) + self.linear5 = Linear(self.nb_intermediate2, self.nb_intermediate2) + + self.drop1 = Dropout(dropout_ratio) + self.drop2 = Dropout(dropout_ratio) + self.drop3 = Dropout(dropout_ratio) + self.drop4 = Dropout(dropout_ratio) + self.drop5 = Dropout(dropout_ratio) + + self.out = Linear(self.nb_intermediate2, self.nb_outputs) + +
+[docs] + def forward(self, data: Data) -> Tensor: + """Apply learnable forward pass.""" + # Convenience variables + x, edge_index, batch = data.x, data.edge_index, data.batch + + # Graph convolutional operations + x = F.leaky_relu(self.conv1(x, edge_index)) + x1 = torch.cat( + [ + global_add_pool(x, batch), + global_max_pool(x, batch), + ], + dim=1, + ) + + x = F.leaky_relu(self.conv2(x, edge_index)) + x2 = torch.cat( + [ + global_add_pool(x, batch), + global_max_pool(x, batch), + ], + dim=1, + ) + + x = F.leaky_relu(self.conv3(x, edge_index)) + x3 = torch.cat( + [ + global_add_pool(x, batch), + global_max_pool(x, batch), + ], + dim=1, + ) + + # Skip-cat + x = torch.cat([x1, x2, x3], dim=1) + + # Batch-normalising intermediate features + x = self.batchnorm1(x) + + # Post-processing + x = F.leaky_relu(self.linear1(x)) + x = self.drop1(x) + x = F.leaky_relu(self.linear2(x)) + x = self.drop2(x) + x = F.leaky_relu(self.linear3(x)) + x = self.drop3(x) + x = F.leaky_relu(self.linear4(x)) + x = self.drop4(x) + x = F.leaky_relu(self.linear5(x)) + x = self.drop5(x) + + # Read-out + x = self.out(x) + + return x
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/dynedge.html b/_modules/graphnet/models/gnn/dynedge.html new file mode 100644 index 000000000..d882546a3 --- /dev/null +++ b/_modules/graphnet/models/gnn/dynedge.html @@ -0,0 +1,693 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.gnn.dynedge — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.gnn.dynedge

+"""Implementation of the DynEdge GNN model architecture."""
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor, LongTensor
+from torch_geometric.data import Data
+from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
+
+from graphnet.models.components.layers import DynEdgeConv
+from graphnet.utilities.config import save_model_config
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.utils import calculate_xyzt_homophily
+
+GLOBAL_POOLINGS = {
+    "min": scatter_min,
+    "max": scatter_max,
+    "sum": scatter_sum,
+    "mean": scatter_mean,
+}
+
+
+
+[docs] +class DynEdge(GNN): + """DynEdge (dynamical edge convolutional) model.""" + + @save_model_config + def __init__( + self, + nb_inputs: int, + *, + nb_neighbours: int = 8, + features_subset: Optional[Union[List[int], slice]] = None, + dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None, + post_processing_layer_sizes: Optional[List[int]] = None, + readout_layer_sizes: Optional[List[int]] = None, + global_pooling_schemes: Optional[Union[str, List[str]]] = None, + add_global_variables_after_pooling: bool = False, + ): + """Construct `DynEdge`. + + Args: + nb_inputs: Number of input features on each node. + nb_neighbours: Number of neighbours to used in the k-nearest + neighbour clustering which is performed after each (dynamical) + edge convolution. + features_subset: The subset of latent features on each node that + are used as metric dimensions when performing the k-nearest + neighbours clustering. Defaults to [0,1,2]. + dynedge_layer_sizes: The layer sizes, or latent feature dimenions, + used in the `DynEdgeConv` layer. Each entry in + `dynedge_layer_sizes` corresponds to a single `DynEdgeConv` + layer; the integers in the corresponding tuple corresponds to + the layer sizes in the multi-layer perceptron (MLP) that is + applied within each `DynEdgeConv` layer. That is, a list of + size-two tuples means that all `DynEdgeConv` layers contain a + two-layer MLP. + Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)]. + post_processing_layer_sizes: Hidden layer sizes in the MLP + following the skip-concatenation of the outputs of each + `DynEdgeConv` layer. Defaults to [336, 256]. + readout_layer_sizes: Hidden layer sizes in the MLP following the + post-processing _and_ optional global pooling. As this is the + last layer(s) in the model, the last layer in the read-out + yields the output of the `DynEdge` model. Defaults to [128,]. + global_pooling_schemes: The list global pooling schemes to use. + Options are: "min", "max", "mean", and "sum". + add_global_variables_after_pooling: Whether to add global variables + after global pooling. The alternative is to added (distribute) + them to the individual nodes before any convolutional + operations. + """ + # Latent feature subset for computing nearest neighbours in DynEdge. + if features_subset is None: + features_subset = slice(0, 3) + + # DynEdge layer sizes + if dynedge_layer_sizes is None: + dynedge_layer_sizes = [ + ( + 128, + 256, + ), + ( + 336, + 256, + ), + ( + 336, + 256, + ), + ( + 336, + 256, + ), + ] + + assert isinstance(dynedge_layer_sizes, list) + assert len(dynedge_layer_sizes) + assert all(isinstance(sizes, tuple) for sizes in dynedge_layer_sizes) + assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes) + assert all( + all(size > 0 for size in sizes) for sizes in dynedge_layer_sizes + ) + + self._dynedge_layer_sizes = dynedge_layer_sizes + + # Post-processing layer sizes + if post_processing_layer_sizes is None: + post_processing_layer_sizes = [ + 336, + 256, + ] + + assert isinstance(post_processing_layer_sizes, list) + assert len(post_processing_layer_sizes) + assert all(size > 0 for size in post_processing_layer_sizes) + + self._post_processing_layer_sizes = post_processing_layer_sizes + + # Read-out layer sizes + if readout_layer_sizes is None: + readout_layer_sizes = [ + 128, + ] + + assert isinstance(readout_layer_sizes, list) + assert len(readout_layer_sizes) + assert all(size > 0 for size in readout_layer_sizes) + + self._readout_layer_sizes = readout_layer_sizes + + # Global pooling scheme(s) + if isinstance(global_pooling_schemes, str): + global_pooling_schemes = [global_pooling_schemes] + + if isinstance(global_pooling_schemes, list): + for pooling_scheme in global_pooling_schemes: + assert ( + pooling_scheme in GLOBAL_POOLINGS + ), f"Global pooling scheme {pooling_scheme} not supported." + else: + assert global_pooling_schemes is None + + self._global_pooling_schemes = global_pooling_schemes + + if add_global_variables_after_pooling: + assert self._global_pooling_schemes, ( + "No global pooling schemes were request, so cannot add global" + " variables after pooling." + ) + self._add_global_variables_after_pooling = ( + add_global_variables_after_pooling + ) + + # Base class constructor + super().__init__(nb_inputs, self._readout_layer_sizes[-1]) + + # Remaining member variables() + self._activation = torch.nn.LeakyReLU() + self._nb_inputs = nb_inputs + self._nb_global_variables = 5 + nb_inputs + self._nb_neighbours = nb_neighbours + self._features_subset = features_subset + + self._construct_layers() + + def _construct_layers(self) -> None: + """Construct layers (torch.nn.Modules).""" + # Convolutional operations + nb_input_features = self._nb_inputs + if not self._add_global_variables_after_pooling: + nb_input_features += self._nb_global_variables + + self._conv_layers = torch.nn.ModuleList() + nb_latent_features = nb_input_features + for sizes in self._dynedge_layer_sizes: + layers = [] + layer_sizes = [nb_latent_features] + list(sizes) + for ix, (nb_in, nb_out) in enumerate( + zip(layer_sizes[:-1], layer_sizes[1:]) + ): + if ix == 0: + nb_in *= 2 + layers.append(torch.nn.Linear(nb_in, nb_out)) + layers.append(self._activation) + + conv_layer = DynEdgeConv( + torch.nn.Sequential(*layers), + aggr="add", + nb_neighbors=self._nb_neighbours, + features_subset=self._features_subset, + ) + self._conv_layers.append(conv_layer) + + nb_latent_features = nb_out + + # Post-processing operations + nb_latent_features = ( + sum(sizes[-1] for sizes in self._dynedge_layer_sizes) + + nb_input_features + ) + + post_processing_layers = [] + layer_sizes = [nb_latent_features] + list( + self._post_processing_layer_sizes + ) + for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): + post_processing_layers.append(torch.nn.Linear(nb_in, nb_out)) + post_processing_layers.append(self._activation) + + self._post_processing = torch.nn.Sequential(*post_processing_layers) + + # Read-out operations + nb_poolings = ( + len(self._global_pooling_schemes) + if self._global_pooling_schemes + else 1 + ) + nb_latent_features = nb_out * nb_poolings + if self._add_global_variables_after_pooling: + nb_latent_features += self._nb_global_variables + + readout_layers = [] + layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes) + for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): + readout_layers.append(torch.nn.Linear(nb_in, nb_out)) + readout_layers.append(self._activation) + + self._readout = torch.nn.Sequential(*readout_layers) + + def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor: + """Perform global pooling.""" + assert self._global_pooling_schemes + pooled = [] + for pooling_scheme in self._global_pooling_schemes: + pooling_fn = GLOBAL_POOLINGS[pooling_scheme] + pooled_x = pooling_fn(x, index=batch, dim=0) + if isinstance(pooled_x, tuple) and len(pooled_x) == 2: + # `scatter_{min,max}`, which return also an argument, vs. + # `scatter_{mean,sum}` + pooled_x, _ = pooled_x + pooled.append(pooled_x) + + return torch.cat(pooled, dim=1) + + def _calculate_global_variables( + self, + x: Tensor, + edge_index: LongTensor, + batch: LongTensor, + *additional_attributes: Tensor, + ) -> Tensor: + """Calculate global variables.""" + # Calculate homophily (scalar variables) + h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch) + + # Calculate mean features + global_means = scatter_mean(x, batch, dim=0) + + # Add global variables + global_variables = torch.cat( + [ + global_means, + h_x, + h_y, + h_z, + h_t, + ] + + [attr.unsqueeze(dim=1) for attr in additional_attributes], + dim=1, + ) + + return global_variables + +
+[docs] + def forward(self, data: Data) -> Tensor: + """Apply learnable forward pass.""" + # Convenience variables + x, edge_index, batch = data.x, data.edge_index, data.batch + + global_variables = self._calculate_global_variables( + x, + edge_index, + batch, + torch.log10(data.n_pulses), + ) + + # Distribute global variables out to each node + if not self._add_global_variables_after_pooling: + distribute = ( + batch.unsqueeze(dim=1) == torch.unique(batch).unsqueeze(dim=0) + ).type(torch.float) + + global_variables_distributed = torch.sum( + distribute.unsqueeze(dim=2) + * global_variables.unsqueeze(dim=0), + dim=1, + ) + + x = torch.cat((x, global_variables_distributed), dim=1) + + # DynEdge-convolutions + skip_connections = [x] + for conv_layer in self._conv_layers: + x, edge_index = conv_layer(x, edge_index, batch) + skip_connections.append(x) + + # Skip-cat + x = torch.cat(skip_connections, dim=1) + + # Post-processing + x = self._post_processing(x) + + # (Optional) Global pooling + if self._global_pooling_schemes: + x = self._global_pooling(x, batch=batch) + if self._add_global_variables_after_pooling: + x = torch.cat( + [ + x, + global_variables, + ], + dim=1, + ) + + # Read-out + x = self._readout(x) + + return x
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/dynedge_jinst.html b/_modules/graphnet/models/gnn/dynedge_jinst.html new file mode 100644 index 000000000..125dc1399 --- /dev/null +++ b/_modules/graphnet/models/gnn/dynedge_jinst.html @@ -0,0 +1,521 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.gnn.dynedge_jinst — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.gnn.dynedge_jinst

+"""Implementation of the exact DynEdge architecture used in [2209.03042].
+
+Author: Rasmus Oersoe
+"""
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch_geometric.data import Data
+from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
+
+from graphnet.models.components.layers import DynEdgeConv
+from graphnet.utilities.config import save_model_config
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.utils import calculate_xyzt_homophily
+
+
+
+[docs] +class DynEdgeJINST(GNN): + """DynEdge (dynamical edge convolutional) model used in [2209.03042].""" + + @save_model_config + def __init__( + self, + nb_inputs: int, + layer_size_scale: int = 4, + ): + """Construct `DynEdgeJINST`. + + Args: + nb_inputs: Number of input features. + nb_outputs: Number of output features. + layer_size_scale: Integer that scales the size of hidden layers. + """ + # Architecture configuration + c = layer_size_scale + l1, l2, l3, l4, l5, l6 = ( + nb_inputs, + c * 16 * 2, + c * 32 * 2, + c * 42 * 2, + c * 32 * 2, + c * 16 * 2, + ) + + # Base class constructor + super().__init__(nb_inputs, l6) + + # Graph convolutional operations + features_subset = slice(0, 3) + nb_neighbors = 8 + + self.conv_add1 = DynEdgeConv( + torch.nn.Sequential( + torch.nn.Linear(l1 * 2, l2), + torch.nn.LeakyReLU(), + torch.nn.Linear(l2, l3), + torch.nn.LeakyReLU(), + ), + aggr="add", + nb_neighbors=nb_neighbors, + features_subset=features_subset, + ) + + self.conv_add2 = DynEdgeConv( + torch.nn.Sequential( + torch.nn.Linear(l3 * 2, l4), + torch.nn.LeakyReLU(), + torch.nn.Linear(l4, l3), + torch.nn.LeakyReLU(), + ), + aggr="add", + nb_neighbors=nb_neighbors, + features_subset=features_subset, + ) + + self.conv_add3 = DynEdgeConv( + torch.nn.Sequential( + torch.nn.Linear(l3 * 2, l4), + torch.nn.LeakyReLU(), + torch.nn.Linear(l4, l3), + torch.nn.LeakyReLU(), + ), + aggr="add", + nb_neighbors=nb_neighbors, + features_subset=features_subset, + ) + + self.conv_add4 = DynEdgeConv( + torch.nn.Sequential( + torch.nn.Linear(l3 * 2, l4), + torch.nn.LeakyReLU(), + torch.nn.Linear(l4, l3), + torch.nn.LeakyReLU(), + ), + aggr="add", + nb_neighbors=nb_neighbors, + features_subset=features_subset, + ) + + # Post-processing operations + self.nn1 = torch.nn.Linear(l3 * 4 + l1, l4) + self.nn2 = torch.nn.Linear(l4, l5) + self.nn3 = torch.nn.Linear(4 * l5 + 5, l6) + self.lrelu = torch.nn.LeakyReLU() + +
+[docs] + def forward(self, data: Data) -> Tensor: + """Apply learnable forward pass.""" + # Convenience variables + x, edge_index, batch = data.x, data.edge_index, data.batch + + # Calculate homophily (scalar variables) + h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch) + + a, edge_index = self.conv_add1(x, edge_index, batch) + b, edge_index = self.conv_add2(a, edge_index, batch) + c, edge_index = self.conv_add3(b, edge_index, batch) + d, edge_index = self.conv_add4(c, edge_index, batch) + + # Skip-cat + x = torch.cat((x, a, b, c, d), dim=1) + + # Post-processing + x = self.nn1(x) + x = self.lrelu(x) + x = self.nn2(x) + + # Aggregation across nodes + a, _ = scatter_max(x, batch, dim=0) + b, _ = scatter_min(x, batch, dim=0) + c = scatter_sum(x, batch, dim=0) + d = scatter_mean(x, batch, dim=0) + + # Concatenate aggregations and scalar features + x = torch.cat( + ( + a, + b, + c, + d, + h_t.reshape(-1, 1), + h_x.reshape(-1, 1), + h_y.reshape(-1, 1), + h_z.reshape(-1, 1), + data.n_pulses.reshape(-1, 1), + ), + dim=1, + ) + + # Read-out + x = self.lrelu(x) + x = self.nn3(x) + + x = self.lrelu(x) + + return x
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html new file mode 100644 index 000000000..8895fa50c --- /dev/null +++ b/_modules/graphnet/models/gnn/dynedge_kaggle_tito.html @@ -0,0 +1,618 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.gnn.dynedge_kaggle_tito — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.gnn.dynedge_kaggle_tito

+"""Implementation of DynEdge architecture used in.
+
+                    IceCube - Neutrinos in Deep Ice
+Reconstruct the direction of neutrinos from the Universe to the South Pole
+
+Kaggle competition.
+
+Solution by TITO.
+"""
+
+from typing import List, Tuple, Optional
+
+import torch
+from torch import Tensor, LongTensor
+
+from torch_geometric.data import Data
+from torch_geometric.utils import to_dense_batch
+from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
+
+from graphnet.models.components.layers import DynTrans
+from graphnet.utilities.config import save_model_config
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.utils import calculate_xyzt_homophily
+
+GLOBAL_POOLINGS = {
+    "min": scatter_min,
+    "max": scatter_max,
+    "sum": scatter_sum,
+    "mean": scatter_mean,
+}
+
+
+
+[docs] +class DynEdgeTITO(GNN): + """DynEdge (dynamical edge convolutional) model.""" + + @save_model_config + def __init__( + self, + nb_inputs: int, + features_subset: slice = slice(0, 4), + dyntrans_layer_sizes: Optional[List[Tuple[int, ...]]] = None, + global_pooling_schemes: List[str] = ["max"], + ): + """Construct `DynEdge`. + + Args: + nb_inputs: Number of input features on each node. + features_subset: The subset of latent features on each node that + are used as metric dimensions when performing the k-nearest + neighbours clustering. Defaults to [0,1,2,3]. + dyntrans_layer_sizes: The layer sizes, or latent feature dimenions, + used in the `DynTrans` layer. + global_pooling_schemes: The list global pooling schemes to use. + Options are: "min", "max", "mean", and "sum". + """ + # DynEdge layer sizes + if dyntrans_layer_sizes is None: + dyntrans_layer_sizes = [ + ( + 256, + 256, + ), + ( + 256, + 256, + ), + ( + 256, + 256, + ), + ] + + assert isinstance(dyntrans_layer_sizes, list) + assert len(dyntrans_layer_sizes) + assert all(isinstance(sizes, tuple) for sizes in dyntrans_layer_sizes) + assert all(len(sizes) > 0 for sizes in dyntrans_layer_sizes) + assert all( + all(size > 0 for size in sizes) for sizes in dyntrans_layer_sizes + ) + + self._dyntrans_layer_sizes = dyntrans_layer_sizes + + # Post-processing layer sizes + post_processing_layer_sizes = [ + 336, + 256, + ] + + self._post_processing_layer_sizes = post_processing_layer_sizes + + # Read-out layer sizes + readout_layer_sizes = [ + 256, + 128, + ] + + self._readout_layer_sizes = readout_layer_sizes + + # Global pooling scheme(s) + if isinstance(global_pooling_schemes, str): + global_pooling_schemes = [global_pooling_schemes] + + if isinstance(global_pooling_schemes, list): + for pooling_scheme in global_pooling_schemes: + assert ( + pooling_scheme in GLOBAL_POOLINGS + ), f"Global pooling scheme {pooling_scheme} not supported." + else: + assert global_pooling_schemes is None + + self._global_pooling_schemes = global_pooling_schemes + + assert self._global_pooling_schemes, ( + "No global pooling schemes were request, so cannot add global" + " variables after pooling." + ) + + # Base class constructor + super().__init__(nb_inputs, self._readout_layer_sizes[-1]) + + # Remaining member variables() + self._activation = torch.nn.LeakyReLU() + self._nb_inputs = nb_inputs + self._nb_global_variables = 5 + nb_inputs + self._features_subset = features_subset + self._construct_layers() + + def _construct_layers(self) -> None: + """Construct layers (torch.nn.Modules).""" + # Convolutional operations + nb_input_features = self._nb_inputs + + self._conv_layers = torch.nn.ModuleList() + nb_latent_features = nb_input_features + for sizes in self._dyntrans_layer_sizes: + conv_layer = DynTrans( + [nb_latent_features] + list(sizes), + aggr="max", + features_subset=self._features_subset, + n_head=8, + ) + self._conv_layers.append(conv_layer) + nb_latent_features = sizes[-1] + + post_processing_layers = [] + layer_sizes = [nb_latent_features] + list( + self._post_processing_layer_sizes + ) + for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): + post_processing_layers.append(torch.nn.Linear(nb_in, nb_out)) + post_processing_layers.append(self._activation) + last_posting_layer_output_dim = nb_out + + self._post_processing = torch.nn.Sequential(*post_processing_layers) + + # Read-out operations + nb_poolings = ( + len(self._global_pooling_schemes) + if self._global_pooling_schemes + else 1 + ) + nb_latent_features = last_posting_layer_output_dim * nb_poolings + nb_latent_features += self._nb_global_variables + + readout_layers = [] + layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes) + for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]): + readout_layers.append(torch.nn.Linear(nb_in, nb_out)) + readout_layers.append(self._activation) + + self._readout = torch.nn.Sequential(*readout_layers) + + def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor: + """Perform global pooling.""" + assert self._global_pooling_schemes + pooled = [] + for pooling_scheme in self._global_pooling_schemes: + pooling_fn = GLOBAL_POOLINGS[pooling_scheme] + pooled_x = pooling_fn(x, index=batch, dim=0) + if isinstance(pooled_x, tuple) and len(pooled_x) == 2: + # `scatter_{min,max}`, which return also an argument, vs. + # `scatter_{mean,sum}` + pooled_x, _ = pooled_x + pooled.append(pooled_x) + + return torch.cat(pooled, dim=1) + + def _calculate_global_variables( + self, + x: Tensor, + edge_index: LongTensor, + batch: LongTensor, + *additional_attributes: Tensor, + ) -> Tensor: + """Calculate global variables.""" + # Calculate homophily (scalar variables) + h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch) + + # Calculate mean features + global_means = scatter_mean(x, batch, dim=0) + + # Add global variables + global_variables = torch.cat( + [ + global_means, + h_x, + h_y, + h_z, + h_t, + ] + + [attr.unsqueeze(dim=1) for attr in additional_attributes], + dim=1, + ) + + return global_variables + +
+[docs] + def forward(self, data: Data) -> Tensor: + """Apply learnable forward pass.""" + # Convenience variables + x, edge_index, batch = data.x, data.edge_index, data.batch + + global_variables = self._calculate_global_variables( + x, + edge_index, + batch, + torch.log10(data.n_pulses), + ) + + # DynEdge-convolutions + for conv_layer in self._conv_layers: + x = conv_layer(x, edge_index, batch) + + x, mask = to_dense_batch(x, batch) + x = x[mask] + + # Post-processing + x = self._post_processing(x) + + # (Optional) Global pooling + x = self._global_pooling(x, batch=batch) + x = torch.cat( + [ + x, + global_variables, + ], + dim=1, + ) + + # Read-out + x = self._readout(x) + + return x
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/gnn/gnn.html b/_modules/graphnet/models/gnn/gnn.html new file mode 100644 index 000000000..7a3a22f46 --- /dev/null +++ b/_modules/graphnet/models/gnn/gnn.html @@ -0,0 +1,403 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.gnn.gnn — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.gnn.gnn

+"""Base GNN-specific `Model` class(es)."""
+
+from abc import abstractmethod
+
+from torch import Tensor
+from torch_geometric.data import Data
+
+from graphnet.models import Model
+from graphnet.utilities.config import save_model_config
+
+
+
+[docs] +class GNN(Model): + """Base class for all core GNN models in graphnet.""" + + @save_model_config + def __init__(self, nb_inputs: int, nb_outputs: int) -> None: + """Construct `GNN`.""" + # Base class constructor + super().__init__() + + # Member variables + self._nb_inputs = nb_inputs + self._nb_outputs = nb_outputs + + @property + def nb_inputs(self) -> int: + """Return number of input features.""" + return self._nb_inputs + + @property + def nb_outputs(self) -> int: + """Return number of output features.""" + return self._nb_outputs + +
+[docs] + @abstractmethod + def forward(self, data: Data) -> Tensor: + """Apply learnable forward pass in model."""
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/edges/edges.html b/_modules/graphnet/models/graphs/edges/edges.html new file mode 100644 index 000000000..e681629cf --- /dev/null +++ b/_modules/graphnet/models/graphs/edges/edges.html @@ -0,0 +1,563 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.graphs.edges.edges — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.graphs.edges.edges

+"""Class(es) for building/connecting graphs."""
+
+from typing import List
+from abc import abstractmethod, ABC
+
+import torch
+from torch_geometric.nn import knn_graph, radius_graph
+from torch_geometric.data import Data
+
+from graphnet.utilities.config import save_model_config
+from graphnet.models.utils import calculate_distance_matrix
+from graphnet.models import Model
+
+
+
+[docs] +class EdgeDefinition(Model): # pylint: disable=too-few-public-methods + """Base class for graph building.""" + +
+[docs] + def forward(self, graph: Data) -> Data: + """Construct edges based on problem specific implementation of. + + ´_construct_edges´ + + Args: + graph: a graph without edges + + Returns: + graph: a graph with edges + """ + if graph.edge_index is not None: + self.warnonce( + "GraphBuilder received graph with pre-existing " + "structure. Will overwrite." + ) + return self._construct_edges(graph)
+ + + @abstractmethod + def _construct_edges(self, graph: Data) -> Data: + """Construct edges and assign them to graph. I.e. ´graph.edge_index = edge_index´. + + Args: + graph: graph without edges + + Returns: + graph: graph with edges assigned. + """
+ + + +
+[docs] +class KNNEdges(EdgeDefinition): # pylint: disable=too-few-public-methods + """Builds edges from the k-nearest neighbours.""" + + @save_model_config + def __init__( + self, + nb_nearest_neighbours: int, + columns: List[int] = [0, 1, 2], + ): + """K-NN Edge definition. + + Will connect nodes together with their ´nb_nearest_neighbours´ + nearest neighbours in the feature space given by ´columns´. + + Args: + nb_nearest_neighbours: number of neighbours. + columns: Node features to use for distance calculation. + Defaults to [0,1,2]. + """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Member variable(s) + self._nb_nearest_neighbours = nb_nearest_neighbours + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Define K-NN edges.""" + graph.edge_index = knn_graph( + graph.x[:, self._columns], + self._nb_nearest_neighbours, + graph.batch, + ).to(self.device) + + return graph
+ + + +
+[docs] +class RadialEdges(EdgeDefinition): + """Builds graph from a sphere of chosen radius centred at each node.""" + + @save_model_config + def __init__( + self, + radius: float, + columns: List[int] = [0, 1, 2], + ): + """Radial edges. + + Connects each node to other nodes that are within a sphere of + radius ´r´ centered at the node. The feature space of ´r´ is defined + by ´columns´ + + Args: + radius: radius of sphere + columns: columns of the node feature matrix used. + Defaults to [0,1,2]. + """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Member variable(s) + self._radius = radius + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Define radial edges.""" + graph.edge_index = radius_graph( + graph.x[:, self._columns], + self._radius, + graph.batch, + ).to(self.device) + + return graph
+ + + +
+[docs] +class EuclideanEdges(EdgeDefinition): # pylint: disable=too-few-public-methods + """Builds edges according to Euclidean distance between nodes. + + See https://arxiv.org/pdf/1809.06166.pdf. + """ + + @save_model_config + def __init__( + self, + sigma: float, + threshold: float = 0.0, + columns: List[int] = None, + ): + """Construct `EuclideanEdges`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Check(s) + if columns is None: + columns = [0, 1, 2] + + # Member variable(s) + self._sigma = sigma + self._threshold = threshold + self._columns = columns + + def _construct_edges(self, graph: Data) -> Data: + """Forward pass.""" + # Constructs the adjacency matrix from the raw, DOM-level data and + # returns this matrix + if graph.edge_index is not None: + self.info( + "WARNING: GraphBuilder received graph with pre-existing " + "structure. Will overwrite." + ) + + xyz_coords = graph.x[:, self._columns] + + # Construct block-diagonal matrix indicating whether pulses belong to + # the same event in the batch + batch_mask = graph.batch.unsqueeze(dim=0) == graph.batch.unsqueeze( + dim=1 + ) + + distance_matrix = calculate_distance_matrix(xyz_coords) + affinity_matrix = torch.exp( + -0.5 * distance_matrix**2 / self._sigma**2 + ) + + # Use softmax to normalise all adjacencies to one for each node + exp_row_sums = torch.exp(affinity_matrix).sum(axis=1) + weighted_adj_matrix = torch.exp( + affinity_matrix + ) / exp_row_sums.unsqueeze(dim=1) + + # Only include edges with weights that exceed the chosen threshold (and + # are part of the same event) + sources, targets = torch.where( + (weighted_adj_matrix > self._threshold) & (batch_mask) + ) + edge_weights = weighted_adj_matrix[sources, targets] + + graph.edge_index = torch.stack((sources, targets)) + graph.edge_weight = edge_weights + + return graph
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/graph_definition.html b/_modules/graphnet/models/graphs/graph_definition.html new file mode 100644 index 000000000..75075b762 --- /dev/null +++ b/_modules/graphnet/models/graphs/graph_definition.html @@ -0,0 +1,634 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.graphs.graph_definition — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.graphs.graph_definition

+"""Modules for defining graphs.
+
+These are self-contained graph definitions that hold all the graph-altering
+code in graphnet. These modules define what the GNNs sees as input and can be
+passed to dataloaders during training and deployment.
+"""
+
+
+from typing import Any, List, Optional, Dict, Callable
+import torch
+from torch_geometric.data import Data
+import numpy as np
+
+from graphnet.utilities.config import save_model_config
+
+from graphnet.models.detector import Detector
+from .edges import EdgeDefinition
+from .nodes import NodeDefinition
+from graphnet.models import Model
+
+
+
+[docs] +class GraphDefinition(Model): + """An Abstract class to create graph definitions from.""" + + @save_model_config + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition, + edge_definition: Optional[EdgeDefinition] = None, + node_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + ): + """Construct ´GraphDefinition´. The ´detector´ holds. + + ´Detector´-specific code. E.g. scaling/standardization and geometry + tables. + + ´node_definition´ defines the nodes in the graph. + + ´edge_definition´ defines the connectivity of the nodes in the graph. + + Args: + detector: The corresponding ´Detector´ representing the data. + node_definition: Definition of nodes. + edge_definition: Definition of edges. Defaults to None. + node_feature_names: Names of node feature columns. Defaults to None + dtype: data type used for node features. e.g. ´torch.float´ + """ + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Member Variables + self._detector = detector + self._edge_definition = edge_definition + self._node_definition = node_definition + if node_feature_names is None: + # Assume all features in Detector is used. + node_feature_names = list(self._detector.feature_map().keys()) # type: ignore + self._node_feature_names = node_feature_names + + # Set data type + self.to(dtype) + + # Set Input / Output dimensions + self._node_definition.set_number_of_inputs( + node_feature_names=node_feature_names + ) + self.nb_inputs = len(self._node_feature_names) + self.nb_outputs = self._node_definition.nb_outputs + +
+[docs] + def forward( # type: ignore + self, + node_features: np.ndarray, + node_feature_names: List[str], + truth_dicts: Optional[List[Dict[str, Any]]] = None, + custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None, + loss_weight_column: Optional[str] = None, + loss_weight: Optional[float] = None, + loss_weight_default_value: Optional[float] = None, + data_path: Optional[str] = None, + ) -> Data: + """Construct graph as ´Data´ object. + + Args: + node_features: node features for graph. Shape ´[num_nodes, d]´ + node_feature_names: name of each column. Shape ´[,d]´. + truth_dicts: Dictionary containing truth labels. + custom_label_functions: Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels. + loss_weight_column: Name of column that holds loss weight. Defaults to None. + loss_weight: Loss weight associated with event. Defaults to None. + loss_weight_default_value: default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None. + data_path: Path to dataset data files. Defaults to None. + + Returns: + graph + """ + # Checks + self._validate_input( + node_features=node_features, node_feature_names=node_feature_names + ) + + # Transform to pytorch tensor + node_features = torch.tensor(node_features, dtype=self.dtype) + + # Standardize / Scale node features + node_features = self._detector(node_features, node_feature_names) + + # Create graph + graph = self._node_definition(node_features) + + # Attach number of pulses as static attribute. + graph.n_pulses = torch.tensor(len(node_features), dtype=torch.int32) + + # Assign edges + if self._edge_definition is not None: + graph = self._edge_definition(graph) + else: + self.warnonce( + "No EdgeDefinition provided. Graphs will not have edges defined!" + ) + + # Attach data path - useful for Ensemble datasets. + if data_path is not None: + graph["dataset_path"] = data_path + + # Attach loss weights if they exist + graph = self._add_loss_weights( + graph=graph, + loss_weight=loss_weight, + loss_weight_column=loss_weight_column, + loss_weight_default_value=loss_weight_default_value, + ) + + # Attach default truth labels and node truths + if truth_dicts is not None: + graph = self._add_truth(graph=graph, truth_dicts=truth_dicts) + + # Attach custom truth labels + if custom_label_functions is not None: + graph = self._add_custom_labels( + graph=graph, custom_label_functions=custom_label_functions + ) + + # Attach node features as seperate fields. MAY NOT CONTAIN 'x' + graph = self._add_features_individually( + graph=graph, node_feature_names=node_feature_names + ) + + # Add GraphDefinition Stamp + graph["graph_definition"] = self.__class__.__name__ + return graph
+ + + def _validate_input( + self, node_features: np.array, node_feature_names: List[str] + ) -> None: + # node feature matrix dimension check + assert node_features.shape[1] == len(node_feature_names) + + # check that provided features for input is the same that the ´Graph´ + # was instantiated with. + assert len(node_feature_names) == len( + self._node_feature_names + ), f"""Input features ({node_feature_names}) is not what {self.__class__.__name__} was instatiated with ({self._node_feature_names})""" + for idx in range(len(node_feature_names)): + assert ( + node_feature_names[idx] == self._node_feature_names[idx] + ), f""" Order of node features in data are not the same as expected. Got {node_feature_names} vs. {self._node_feature_names}""" + + def _add_loss_weights( + self, + graph: Data, + loss_weight_column: Optional[str] = None, + loss_weight: Optional[float] = None, + loss_weight_default_value: Optional[float] = None, + ) -> Data: + """Attempt to store a loss weight in the graph for use during training. + + I.e. `graph[loss_weight_column] = loss_weight` + + Args: + loss_weight: The non-negative weight to be stored. + graph: Data object representing the event. + loss_weight_column: The name under which the weight is stored in + the graph. + loss_weight_default_value: The default value used if + none was retrieved. + + Returns: + A graph with loss weight added, if available. + """ + # Add loss weight to graph. + if loss_weight is not None and loss_weight_column is not None: + # No loss weight was retrieved, i.e., it is missing for the current + # event. + if loss_weight < 0: + if loss_weight_default_value is None: + raise ValueError( + "At least one event is missing an entry in " + f"{loss_weight_column} " + "but loss_weight_default_value is None." + ) + graph[loss_weight_column] = torch.tensor( + self._loss_weight_default_value, dtype=self.dtype + ).reshape(-1, 1) + else: + graph[loss_weight_column] = torch.tensor( + loss_weight, dtype=self.dtype + ).reshape(-1, 1) + return graph + + def _add_truth( + self, graph: Data, truth_dicts: List[Dict[str, Any]] + ) -> Data: + """Add truth labels from ´truth_dicts´ to ´graph´. + + I.e. ´graph[key] = truth_dict[key]´ + + + Args: + graph: graph where the label will be stored + truth_dicts: dictionary containing the labels + + Returns: + graph with labels + """ + # Write attributes, either target labels, truth info or original + # features. + for truth_dict in truth_dicts: + for key, value in truth_dict.items(): + try: + graph[key] = torch.tensor(value) + except TypeError: + # Cannot convert `value` to Tensor due to its data type, + # e.g. `str`. + self.debug( + ( + f"Could not assign `{key}` with type " + f"'{type(value).__name__}' as attribute to graph." + ) + ) + return graph + + def _add_features_individually( + self, + graph: Data, + node_feature_names: List[str], + ) -> Data: + # Additionally add original features as (static) attributes + graph.features = node_feature_names + for index, feature in enumerate(node_feature_names): + if feature not in ["x"]: # reserved for node features. + graph[feature] = graph.x[:, index].detach() + else: + self.warnonce( + """Cannot assign graph['x']. This field is reserved for node features. Please rename your input feature.""" + ) + return graph + + def _add_custom_labels( + self, + graph: Data, + custom_label_functions: Dict[str, Callable[..., Any]], + ) -> Data: + # Add custom labels to the graph + for key, fn in custom_label_functions.items(): + graph[key] = fn(graph) + return graph
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/graphs.html b/_modules/graphnet/models/graphs/graphs.html new file mode 100644 index 000000000..5f3dc4be7 --- /dev/null +++ b/_modules/graphnet/models/graphs/graphs.html @@ -0,0 +1,410 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.graphs.graphs — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.graphs.graphs

+"""A module containing different graph representations in GraphNeT."""
+
+from typing import List, Optional
+import torch
+
+from graphnet.utilities.config import save_model_config
+from .graph_definition import GraphDefinition
+from graphnet.models.detector import Detector
+from graphnet.models.graphs.edges import EdgeDefinition, KNNEdges
+from graphnet.models.graphs.nodes import NodeDefinition
+
+
+
+[docs] +class KNNGraph(GraphDefinition): + """A Graph representation where Edges are drawn to nearest neighbours.""" + + @save_model_config + def __init__( + self, + detector: Detector, + node_definition: NodeDefinition, + node_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + nb_nearest_neighbours: int = 8, + columns: List[int] = [0, 1, 2], + ) -> None: + """Construct k-nn graph representation. + + Args: + detector: Detector that represents your data. + node_definition: Definition of nodes in the graph. + node_feature_names: Name of node features. + dtype: data type for node features. + nb_nearest_neighbours: Number of edges for each node. Defaults to 8. + columns: node feature columns used for distance calculation + . Defaults to [0, 1, 2]. + """ + # Base class constructor + super().__init__( + detector=detector, + node_definition=node_definition, + edge_definition=KNNEdges( + nb_nearest_neighbours=nb_nearest_neighbours, + columns=columns, + ), + dtype=dtype, + node_feature_names=node_feature_names, + )
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/graphs/nodes/nodes.html b/_modules/graphnet/models/graphs/nodes/nodes.html new file mode 100644 index 000000000..3e7040189 --- /dev/null +++ b/_modules/graphnet/models/graphs/nodes/nodes.html @@ -0,0 +1,444 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.graphs.nodes.nodes — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.graphs.nodes.nodes

+"""Class(es) for building/connecting graphs."""
+
+from typing import List
+from abc import abstractmethod
+
+import torch
+from torch_geometric.data import Data
+
+from graphnet.utilities.decorators import final
+from graphnet.utilities.config import save_model_config
+from graphnet.models import Model
+
+
+
+[docs] +class NodeDefinition(Model): # pylint: disable=too-few-public-methods + """Base class for graph building.""" + + @save_model_config + def __init__(self) -> None: + """Construct `Detector`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + +
+[docs] + @final + def forward(self, x: torch.tensor) -> Data: + """Construct nodes from raw node features. + + Args: + x: standardized node features with shape ´[num_pulses, d]´, + where ´d´ is the number of node features. + + Returns: + graph: a graph without edges + """ + graph = self._construct_nodes(x) + return graph
+ + + @property + def nb_outputs(self) -> int: + """Return number of output features. + + This the default, but may be overridden by specific inheriting classes. + """ + return self.nb_inputs + +
+[docs] + @final + def set_number_of_inputs(self, node_feature_names: List[str]) -> None: + """Return number of inputs expected by node definition. + + Args: + node_feature_names: name of each node feature column. + """ + assert isinstance(node_feature_names, list) + self.nb_inputs = len(node_feature_names)
+ + + @abstractmethod + def _construct_nodes(self, x: torch.tensor) -> Data: + """Construct nodes from raw node features ´x´. + + Args: + x: standardized node features with shape ´[num_pulses, d]´, + where ´d´ is the number of node features. + + Returns: + graph: graph without edges. + """
+ + + +
+[docs] +class NodesAsPulses(NodeDefinition): + """Represent each measured pulse of Cherenkov Radiation as a node.""" + + def _construct_nodes(self, x: torch.Tensor) -> Data: + return Data(x=x)
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/model.html b/_modules/graphnet/models/model.html new file mode 100644 index 000000000..9654c1e81 --- /dev/null +++ b/_modules/graphnet/models/model.html @@ -0,0 +1,726 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.model — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.model

+"""Base class(es) for building models."""
+
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+import dill
+import os.path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.loggers.logger import Logger as LightningLogger
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader, SequentialSampler
+from torch_geometric.data import Data
+
+from graphnet.utilities.logging import Logger
+from graphnet.utilities.config import Configurable, ModelConfig
+from graphnet.training.callbacks import ProgressBar
+
+
+
+[docs] +class Model(Logger, Configurable, LightningModule, ABC): + """Base class for all models in graphnet.""" + +
+[docs] + @abstractmethod + def forward(self, x: Union[Tensor, Data]) -> Union[Tensor, Data]: + """Forward pass."""
+ + + @staticmethod + def _construct_trainer( + max_epochs: int = 10, + gpus: Optional[Union[List[int], int]] = None, + callbacks: Optional[List[Callback]] = None, + ckpt_path: Optional[str] = None, + logger: Optional[LightningLogger] = None, + log_every_n_steps: int = 1, + gradient_clip_val: Optional[float] = None, + distribution_strategy: Optional[str] = "ddp", + **trainer_kwargs: Any, + ) -> Trainer: + + if gpus: + accelerator = "gpu" + devices = gpus + else: + accelerator = "cpu" + devices = 1 + + trainer = Trainer( + accelerator=accelerator, + devices=devices, + max_epochs=max_epochs, + callbacks=callbacks, + log_every_n_steps=log_every_n_steps, + logger=logger, + gradient_clip_val=gradient_clip_val, + strategy=distribution_strategy, + default_root_dir=ckpt_path, + **trainer_kwargs, + ) + + return trainer + +
+[docs] + def fit( + self, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader] = None, + *, + max_epochs: int = 10, + gpus: Optional[Union[List[int], int]] = None, + callbacks: Optional[List[Callback]] = None, + ckpt_path: Optional[str] = None, + logger: Optional[LightningLogger] = None, + log_every_n_steps: int = 1, + gradient_clip_val: Optional[float] = None, + distribution_strategy: Optional[str] = "ddp", + **trainer_kwargs: Any, + ) -> None: + """Fit `Model` using `pytorch_lightning.Trainer`.""" + # Checks + if callbacks is None: + callbacks = self._create_default_callbacks( + val_dataloader=val_dataloader, + ) + elif val_dataloader is not None: + callbacks = self._add_early_stopping( + val_dataloader=val_dataloader, callbacks=callbacks + ) + + self.train(mode=True) + trainer = self._construct_trainer( + max_epochs=max_epochs, + gpus=gpus, + callbacks=callbacks, + ckpt_path=ckpt_path, + logger=logger, + log_every_n_steps=log_every_n_steps, + gradient_clip_val=gradient_clip_val, + distribution_strategy=distribution_strategy, + **trainer_kwargs, + ) + + try: + trainer.fit( + self, train_dataloader, val_dataloader, ckpt_path=ckpt_path + ) + except KeyboardInterrupt: + self.warning("[ctrl+c] Exiting gracefully.") + pass
+ + + def _create_default_callbacks(self, val_dataloader: DataLoader) -> List: + callbacks = [ProgressBar()] + callbacks = self._add_early_stopping( + val_dataloader=val_dataloader, callbacks=callbacks + ) + return callbacks + + def _add_early_stopping( + self, val_dataloader: DataLoader, callbacks: List + ) -> List: + if val_dataloader is None: + return callbacks + has_early_stopping = False + assert isinstance(callbacks, list) + for callback in callbacks: + if isinstance(callback, EarlyStopping): + has_early_stopping = True + + if not has_early_stopping: + callbacks.append( + EarlyStopping( + monitor="val_loss", + patience=5, + ) + ) + self.warning_once( + "Got validation dataloader but no EarlyStopping callback. An " + "EarlyStopping callback has been added automatically with " + "patience=5 and monitor = 'val_loss'." + ) + return callbacks + +
+[docs] + def predict( + self, + dataloader: DataLoader, + gpus: Optional[Union[List[int], int]] = None, + distribution_strategy: Optional[str] = "auto", + ) -> List[Tensor]: + """Return predictions for `dataloader`. + + Returns a list of Tensors, one for each model output. + """ + self.train(mode=False) + + callbacks = self._create_default_callbacks( + val_dataloader=None, + ) + + inference_trainer = self._construct_trainer( + gpus=gpus, + distribution_strategy=distribution_strategy, + callbacks=callbacks, + ) + + predictions_list = inference_trainer.predict(self, dataloader) + assert len(predictions_list), "Got no predictions" + + nb_outputs = len(predictions_list[0]) + predictions: List[Tensor] = [ + torch.cat([preds[ix] for preds in predictions_list], dim=0) + for ix in range(nb_outputs) + ] + + return predictions
+ + +
+[docs] + def predict_as_dataframe( + self, + dataloader: DataLoader, + prediction_columns: List[str], + *, + additional_attributes: Optional[List[str]] = None, + gpus: Optional[Union[List[int], int]] = None, + distribution_strategy: Optional[str] = "auto", + ) -> pd.DataFrame: + """Return predictions for `dataloader` as a DataFrame. + + Include `additional_attributes` as additional columns in the output + DataFrame. + """ + # Check(s) + if additional_attributes is None: + additional_attributes = [] + assert isinstance(additional_attributes, list) + + if ( + not isinstance(dataloader.sampler, SequentialSampler) + and additional_attributes + ): + print(dataloader.sampler) + raise UserWarning( + "DataLoader has a `sampler` that is not `SequentialSampler`, " + "indicating that shuffling is enabled. Using " + "`predict_as_dataframe` with `additional_attributes` assumes " + "that the sequence of batches in `dataloader` are " + "deterministic. Either call this method a `dataloader` which " + "doesn't resample batches; or do not request " + "`additional_attributes`." + ) + self.info(f"Column names for predictions are: \n {prediction_columns}") + predictions_torch = self.predict( + dataloader=dataloader, + gpus=gpus, + distribution_strategy=distribution_strategy, + ) + predictions = ( + torch.cat(predictions_torch, dim=1).detach().cpu().numpy() + ) + assert len(prediction_columns) == predictions.shape[1], ( + f"Number of provided column names ({len(prediction_columns)}) and " + f"number of output columns ({predictions.shape[1]}) don't match." + ) + + # Get additional attributes + attributes: Dict[str, List[np.ndarray]] = OrderedDict( + [(attr, []) for attr in additional_attributes] + ) + + for batch in dataloader: + for attr in attributes: + attribute = batch[attr] + if isinstance(attribute, torch.Tensor): + attribute = attribute.detach().cpu().numpy() + + # Check if node level predictions + # If true, additional attributes are repeated + # to make dimensions fit + if len(predictions) != len(dataloader.dataset): + if len(attribute) < np.sum( + batch.n_pulses.detach().cpu().numpy() + ): + attribute = np.repeat( + attribute, batch.n_pulses.detach().cpu().numpy() + ) + try: + assert len(attribute) == len(batch.x) + except AssertionError: + self.warning_once( + "Could not automatically adjust length" + f"of additional attribute {attr} to match length of" + f"predictions. Make sure {attr} is a graph-level or" + "node-level attribute. Attribute skipped." + ) + pass + attributes[attr].extend(attribute) + + data = np.concatenate( + [predictions] + + [ + np.asarray(values)[:, np.newaxis] + for values in attributes.values() + ], + axis=1, + ) + + results = pd.DataFrame( + data, columns=prediction_columns + additional_attributes + ) + return results
+ + +
+[docs] + def save(self, path: str) -> None: + """Save entire model to `path`.""" + if not path.endswith(".pth"): + self.info( + "It is recommended to use the .pth suffix for model files." + ) + dirname = os.path.dirname(path) + if dirname: + os.makedirs(dirname, exist_ok=True) + torch.save(self.cpu(), path, pickle_module=dill) + self.info(f"Model saved to {path}")
+ + +
+[docs] + @classmethod + def load(cls, path: str) -> "Model": + """Load entire model from `path`.""" + return torch.load(path, pickle_module=dill)
+ + +
+[docs] + def save_state_dict(self, path: str) -> None: + """Save model `state_dict` to `path`.""" + if not path.endswith(".pth"): + self.info( + "It is recommended to use the .pth suffix for state_dict files." + ) + torch.save(self.cpu().state_dict(), path) + self.info(f"Model state_dict saved to {path}")
+ + +
+[docs] + def load_state_dict( + self, path: Union[str, Dict], **kargs: Optional[Any] + ) -> "Model": # pylint: disable=arguments-differ + """Load model `state_dict` from `path`.""" + if isinstance(path, str): + state_dict = torch.load(path) + else: + state_dict = path + return super().load_state_dict(state_dict, **kargs)
+ + +
+[docs] + @classmethod + def from_config( # type: ignore[override] + cls, + source: Union[ModelConfig, str], + trust: bool = False, + load_modules: Optional[List[str]] = None, + ) -> "Model": + """Construct `Model` instance from `source` configuration. + + Arguments: + trust: Whether to trust the ModelConfig file enough to `eval(...)` + any lambda function expressions contained. + load_modules: List of modules used in the definition of the model + which, as a consequence, need to be loaded into the global + namespace. Defaults to loading `torch`. + + Raises: + ValueError: If the ModelConfig contains lambda functions but + `trust = False`. + """ + if isinstance(source, str): + source = ModelConfig.load(source) + + assert isinstance( + source, ModelConfig + ), f"Argument `source` of type ({type(source)}) is not a `ModelConfig" + + return source._construct_model(trust, load_modules)
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/standard_model.html b/_modules/graphnet/models/standard_model.html new file mode 100644 index 000000000..cb57cb523 --- /dev/null +++ b/_modules/graphnet/models/standard_model.html @@ -0,0 +1,604 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.standard_model — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.standard_model

+"""Standard model class(es)."""
+
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from torch import Tensor
+from torch.nn import ModuleList
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torch_geometric.data import Data
+import pandas as pd
+
+from graphnet.utilities.config import save_model_config
+from graphnet.models.graphs import GraphDefinition
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.model import Model
+from graphnet.models.task import Task
+
+
+
+[docs] +class StandardModel(Model): + """Main class for standard models in graphnet. + + This class chains together the different elements of a complete GNN-based + model (detector read-in, GNN architecture, and task-specific read-outs). + """ + + @save_model_config + def __init__( + self, + *, + graph_definition: GraphDefinition, + gnn: GNN, + tasks: Union[Task, List[Task]], + optimizer_class: type = Adam, + optimizer_kwargs: Optional[Dict] = None, + scheduler_class: Optional[type] = None, + scheduler_kwargs: Optional[Dict] = None, + scheduler_config: Optional[Dict] = None, + ) -> None: + """Construct `StandardModel`.""" + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + # Check(s) + if isinstance(tasks, Task): + tasks = [tasks] + assert isinstance(tasks, (list, tuple)) + assert all(isinstance(task, Task) for task in tasks) + assert isinstance(graph_definition, GraphDefinition) + assert isinstance(gnn, GNN) + + # Member variable(s) + self._graph_definition = graph_definition + self._gnn = gnn + self._tasks = ModuleList(tasks) + self._optimizer_class = optimizer_class + self._optimizer_kwargs = optimizer_kwargs or dict() + self._scheduler_class = scheduler_class + self._scheduler_kwargs = scheduler_kwargs or dict() + self._scheduler_config = scheduler_config or dict() + + # set dtype of GNN from graph_definition + self._gnn.type(self._graph_definition._dtype) + + @property + def target_labels(self) -> List[str]: + """Return target label.""" + return [label for task in self._tasks for label in task._target_labels] + + @property + def prediction_labels(self) -> List[str]: + """Return prediction labels.""" + return [ + label for task in self._tasks for label in task._prediction_labels + ] + +
+[docs] + def configure_optimizers(self) -> Dict[str, Any]: + """Configure the model's optimizer(s).""" + optimizer = self._optimizer_class( + self.parameters(), **self._optimizer_kwargs + ) + config = { + "optimizer": optimizer, + } + if self._scheduler_class is not None: + scheduler = self._scheduler_class( + optimizer, **self._scheduler_kwargs + ) + config.update( + { + "lr_scheduler": { + "scheduler": scheduler, + **self._scheduler_config, + }, + } + ) + return config
+ + +
+[docs] + def forward(self, data: Data) -> List[Union[Tensor, Data]]: + """Forward pass, chaining model components.""" + assert isinstance(data, Data) + x = self._gnn(data) + preds = [task(x) for task in self._tasks] + return preds
+ + +
+[docs] + def shared_step(self, batch: Data, batch_idx: int) -> Tensor: + """Perform shared step. + + Applies the forward pass and the following loss calculation, shared + between the training and validation step. + """ + preds = self(batch) + loss = self.compute_loss(preds, batch) + return loss
+ + +
+[docs] + def training_step(self, train_batch: Data, batch_idx: int) -> Tensor: + """Perform training step.""" + loss = self.shared_step(train_batch, batch_idx) + self.log( + "train_loss", + loss, + batch_size=self._get_batch_size(train_batch), + prog_bar=True, + on_epoch=True, + on_step=False, + sync_dist=True, + ) + return loss
+ + +
+[docs] + def validation_step(self, val_batch: Data, batch_idx: int) -> Tensor: + """Perform validation step.""" + loss = self.shared_step(val_batch, batch_idx) + self.log( + "val_loss", + loss, + batch_size=self._get_batch_size(val_batch), + prog_bar=True, + on_epoch=True, + on_step=False, + sync_dist=True, + ) + return loss
+ + +
+[docs] + def compute_loss( + self, preds: Tensor, data: Data, verbose: bool = False + ) -> Tensor: + """Compute and sum losses across tasks.""" + losses = [ + task.compute_loss(pred, data) + for task, pred in zip(self._tasks, preds) + ] + if verbose: + self.info(f"{losses}") + assert all( + loss.dim() == 0 for loss in losses + ), "Please reduce loss for each task separately" + return torch.sum(torch.stack(losses))
+ + + def _get_batch_size(self, data: Data) -> int: + return torch.numel(torch.unique(data.batch)) + +
+[docs] + def inference(self) -> None: + """Activate inference mode.""" + for task in self._tasks: + task.inference()
+ + +
+[docs] + def train(self, mode: bool = True) -> "Model": + """Deactivate inference mode.""" + super().train(mode) + if mode: + for task in self._tasks: + task.train_eval() + return self
+ + +
+[docs] + def predict( + self, + dataloader: DataLoader, + gpus: Optional[Union[List[int], int]] = None, + distribution_strategy: Optional[str] = "auto", + ) -> List[Tensor]: + """Return predictions for `dataloader`.""" + self.inference() + return super().predict( + dataloader=dataloader, + gpus=gpus, + distribution_strategy=distribution_strategy, + )
+ + +
+[docs] + def predict_as_dataframe( + self, + dataloader: DataLoader, + prediction_columns: Optional[List[str]] = None, + *, + additional_attributes: Optional[List[str]] = None, + gpus: Optional[Union[List[int], int]] = None, + distribution_strategy: Optional[str] = "auto", + ) -> pd.DataFrame: + """Return predictions for `dataloader` as a DataFrame. + + Include `additional_attributes` as additional columns in the output + DataFrame. + """ + if prediction_columns is None: + prediction_columns = self.prediction_labels + return super().predict_as_dataframe( + dataloader=dataloader, + prediction_columns=prediction_columns, + additional_attributes=additional_attributes, + gpus=gpus, + distribution_strategy=distribution_strategy, + )
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/task/classification.html b/_modules/graphnet/models/task/classification.html new file mode 100644 index 000000000..b3c6c9440 --- /dev/null +++ b/_modules/graphnet/models/task/classification.html @@ -0,0 +1,411 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.task.classification — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.task.classification

+"""Classification-specific `Model` class(es)."""
+
+from typing import Any
+
+import torch
+from torch import Tensor
+
+from graphnet.models.task import Task, IdentityTask
+
+
+
+[docs] +class MulticlassClassificationTask(IdentityTask): + """General task for classifying any number of classes. + + Requires the same number of input features as the number of classes being + predicted. Returns the untransformed latent features, which are interpreted + as the logits for each class being classified. + """
+ + + +
+[docs] +class BinaryClassificationTask(Task): + """Performs binary classification.""" + + # Requires one feature, logit for being signal class. + nb_inputs = 1 + default_target_labels = ["target"] + default_prediction_labels = ["target_pred"] + + def _forward(self, x: Tensor) -> Tensor: + # transform probability of being muon + return torch.sigmoid(x)
+ + + +
+[docs] +class BinaryClassificationTaskLogits(Task): + """Performs binary classification form logits.""" + + # Requires one feature, logit for being signal class. + nb_inputs = 1 + default_target_labels = ["target"] + default_prediction_labels = ["target_pred"] + + def _forward(self, x: Tensor) -> Tensor: + return x
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/task/reconstruction.html b/_modules/graphnet/models/task/reconstruction.html new file mode 100644 index 000000000..ccb644a97 --- /dev/null +++ b/_modules/graphnet/models/task/reconstruction.html @@ -0,0 +1,609 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.task.reconstruction — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.task.reconstruction

+"""Reconstruction-specific `Model` class(es)."""
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from graphnet.models.task import Task
+from graphnet.utilities.maths import eps_like
+
+
+
+[docs] +class AzimuthReconstructionWithKappa(Task): + """Reconstructs azimuthal angle and associated kappa (1/var).""" + + # Requires two features: untransformed points in (x,y)-space. + default_target_labels = ["azimuth"] + default_prediction_labels = ["azimuth_pred", "azimuth_kappa"] + nb_inputs = 2 + + def _forward(self, x: Tensor) -> Tensor: + # Transform outputs to angle and prepare prediction + kappa = torch.linalg.vector_norm(x, dim=1) + eps_like(x) + angle = torch.atan2(x[:, 1], x[:, 0]) + angle = torch.where( + angle < 0, angle + 2 * np.pi, angle + ) # atan(y,x) -> [-pi, pi] + return torch.stack((angle, kappa), dim=1)
+ + + +
+[docs] +class AzimuthReconstruction(AzimuthReconstructionWithKappa): + """Reconstructs azimuthal angle.""" + + # Requires two features: untransformed points in (x,y)-space. + default_target_labels = ["azimuth"] + default_prediction_labels = ["azimuth_pred"] + nb_inputs = 2 + + def _forward(self, x: Tensor) -> Tensor: + # Transform outputs to angle and prepare prediction + res = super()._forward(x) + angle = res[:, 0].unsqueeze(1) + kappa = res[:, 1] + sigma = torch.sqrt(1.0 / kappa) + beta = 1e-3 + kl_loss = torch.mean(sigma**2 - torch.log(sigma) - 1) + self._regularisation_loss += beta * kl_loss + return angle
+ + + +
+[docs] +class DirectionReconstructionWithKappa(Task): + """Reconstructs direction with kappa from the 3D-vMF distribution.""" + + # Requires three features: untransformed points in (x,y,z)-space. + default_target_labels = [ + "direction" + ] # contains dir_x, dir_y, dir_z see https://github.com/graphnet-team/graphnet/blob/95309556cfd46a4046bc4bd7609888aab649e295/src/graphnet/training/labels.py#L29 + default_prediction_labels = [ + "dir_x_pred", + "dir_y_pred", + "dir_z_pred", + "direction_kappa", + ] + nb_inputs = 3 + + def _forward(self, x: Tensor) -> Tensor: + # Transform outputs to angle and prepare prediction + kappa = torch.linalg.vector_norm(x, dim=1) + eps_like(x) + vec_x = x[:, 0] / kappa + vec_y = x[:, 1] / kappa + vec_z = x[:, 2] / kappa + return torch.stack((vec_x, vec_y, vec_z, kappa), dim=1)
+ + + +
+[docs] +class ZenithReconstruction(Task): + """Reconstructs zenith angle.""" + + # Requires two features: zenith angle itself. + default_target_labels = ["zenith"] + default_prediction_labels = ["zenith_pred"] + nb_inputs = 1 + + def _forward(self, x: Tensor) -> Tensor: + # Transform outputs to angle and prepare prediction + return torch.sigmoid(x[:, :1]) * np.pi
+ + + +
+[docs] +class ZenithReconstructionWithKappa(ZenithReconstruction): + """Reconstructs zenith angle and associated kappa (1/var).""" + + # Requires one feature in addition to `ZenithReconstruction`: kappa (unceratinty; 1/variance). + default_target_labels = ["zenith"] + default_prediction_labels = ["zenith_pred", "zenith_kappa"] + nb_inputs = 2 + + def _forward(self, x: Tensor) -> Tensor: + # Transform outputs to angle and prepare prediction + angle = super()._forward(x[:, :1]).squeeze(1) + kappa = torch.abs(x[:, 1]) + eps_like(x) + return torch.stack((angle, kappa), dim=1)
+ + + +
+[docs] +class EnergyReconstruction(Task): + """Reconstructs energy using stable method.""" + + # Requires one feature: untransformed energy + default_target_labels = ["energy"] + default_prediction_labels = ["energy_pred"] + nb_inputs = 1 + + def _forward(self, x: Tensor) -> Tensor: + # Transform to positive energy domain avoiding `-inf` in `log10` + # Transform, thereby preventing overflow and underflow error. + return torch.nn.functional.softplus(x, beta=0.05) + eps_like(x)
+ + + +
+[docs] +class EnergyReconstructionWithPower(Task): + """Reconstructs energy.""" + + # Requires one feature: untransformed energy + default_target_labels = ["energy"] + default_prediction_labels = ["energy_pred"] + nb_inputs = 1 + + def _forward(self, x: Tensor) -> Tensor: + # Transform energy + return torch.pow(10, x[:, 0] + 1.0).unsqueeze(1)
+ + + +
+[docs] +class EnergyReconstructionWithUncertainty(EnergyReconstruction): + """Reconstructs energy and associated uncertainty (log(var)).""" + + # Requires one feature in addition to `EnergyReconstruction`: log-variance (uncertainty). + default_target_labels = ["energy"] + default_prediction_labels = ["energy_pred", "energy_sigma"] + nb_inputs = 2 + + def _forward(self, x: Tensor) -> Tensor: + # Transform energy + energy = super()._forward(x[:, :1]).squeeze(1) + log_var = x[:, 1] + pred = torch.stack((energy, log_var), dim=1) + return pred
+ + + +
+[docs] +class VertexReconstruction(Task): + """Reconstructs vertex position and time.""" + + # Requires four features, x, y, z, and t. + default_target_labels = ["vertex"] + default_prediction_labels = [ + "position_x_pred", + "position_y_pred", + "position_z_pred", + "interaction_time_pred", + ] + nb_inputs = 4 + + def _forward(self, x: Tensor) -> Tensor: + # Scale xyz to roughly the right order of magnitude, leave time + x[:, 0] = x[:, 0] * 1e2 + x[:, 1] = x[:, 1] * 1e2 + x[:, 2] = x[:, 2] * 1e2 + + return x
+ + + +
+[docs] +class PositionReconstruction(Task): + """Reconstructs vertex position.""" + + # Requires three features, x, y, and z. + default_target_labels = ["position"] + default_prediction_labels = [ + "position_x_pred", + "position_y_pred", + "position_z_pred", + ] + nb_inputs = 3 + + def _forward(self, x: Tensor) -> Tensor: + # Scale to roughly the right order of magnitude + x[:, 0] = x[:, 0] * 1e2 + x[:, 1] = x[:, 1] * 1e2 + x[:, 2] = x[:, 2] * 1e2 + + return x
+ + + +
+[docs] +class TimeReconstruction(Task): + """Reconstructs time.""" + + # Requires one feature, time. + default_target_labels = ["interaction_time"] + default_prediction_labels = ["interaction_time_pred"] + nb_inputs = 1 + + def _forward(self, x: Tensor) -> Tensor: + # Leave as it is + return x
+ + + +
+[docs] +class InelasticityReconstruction(Task): + """Reconstructs interaction inelasticity. + + That is, 1-(track energy / hadronic energy). + """ + + # Requires one features: inelasticity itself + default_target_labels = ["inelasticity"] + default_prediction_labels = ["inelasticity_pred"] + nb_inputs = 1 + + def _forward(self, x: Tensor) -> Tensor: + # Transform output to unit range + return torch.sigmoid(x)
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/task/task.html b/_modules/graphnet/models/task/task.html new file mode 100644 index 000000000..c43cb2f73 --- /dev/null +++ b/_modules/graphnet/models/task/task.html @@ -0,0 +1,688 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.task.task — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.task.task

+"""Base physics task-specific `Model` class(es)."""
+
+from abc import abstractmethod
+from typing import Any, TYPE_CHECKING, List, Tuple, Union
+from typing import Callable, Optional
+import numpy as np
+
+import torch
+from torch import Tensor
+from torch.nn import Linear
+from torch_geometric.data import Data
+
+if TYPE_CHECKING:
+    # Avoid cyclic dependency
+    from graphnet.training.loss_functions import LossFunction  # type: ignore[attr-defined]
+
+from graphnet.models import Model
+from graphnet.utilities.config import save_model_config
+from graphnet.utilities.decorators import final
+
+
+
+[docs] +class Task(Model): + """Base class for all reconstruction and classification tasks.""" + + @property + @abstractmethod + def nb_inputs(self) -> int: + """Return number of inputs assumed by task.""" + + @property + @abstractmethod + def default_target_labels(self) -> List[str]: + """Return default target labels.""" + return self._default_target_labels + + @property + @abstractmethod + def default_prediction_labels(self) -> List[str]: + """Return default prediction labels.""" + return self._default_prediction_labels + + @save_model_config + def __init__( + self, + *, + hidden_size: int, + loss_function: "LossFunction", + target_labels: Optional[Union[str, List[str]]] = None, + prediction_labels: Optional[Union[str, List[str]]] = None, + transform_prediction_and_target: Optional[Callable] = None, + transform_target: Optional[Callable] = None, + transform_inference: Optional[Callable] = None, + transform_support: Optional[Tuple] = None, + loss_weight: Optional[str] = None, + ): + """Construct `Task`. + + Args: + hidden_size: The number of nodes in the layer feeding into this + tasks, used to construct the affine transformation to the + predicted quantity. + loss_function: Loss function appropriate to the task. + target_labels: Name(s) of the quantity/-ies being predicted, used + to extract the target tensor(s) from the `Data` object in + `.compute_loss(...)`. + prediction_labels: The name(s) of each column that is predicted by + the model during inference. If not given, the name will auto + matically be set to `target_label + _pred`. + transform_prediction_and_target: Optional function to transform + both the predicted and target tensor before passing them to the + loss function. Useful e.g. for having the model predict + quantities on a physical scale, but transforming this scale to + O(1) for a numerically stable loss computation. + transform_target: Optional function to transform only the target + tensor before passing it, and the predicted tensor, to the loss + function. Useful e.g. for having the model predict a + transformed version of the target quantity, e.g. the log10- + scaled energy, rather than the physical quantity itself. Used + in conjunction with `transform_inference` to perform the + inverse transform on the predicted quantity to recover the + physical scale. + transform_inference: Optional function to inverse-transform the + model prediction to recover a physical scale. Used in + conjunction with `transform_target`. + transform_support: Optional tuple to specify minimum and maximum + of the range of validity for the inverse transforms + `transform_target` and `transform_inference` in case this is + restricted. By default the invertibility of `transform_target` + is tested on the range [-1e6, 1e6]. + loss_weight: Name of the attribute in `data` containing per-event + loss weights. + """ + # Base class constructor + super().__init__() + # Check(s) + if target_labels is None: + target_labels = self.default_target_labels + if isinstance(target_labels, str): + target_labels = [target_labels] + + if prediction_labels is None: + prediction_labels = self.default_prediction_labels + if isinstance(prediction_labels, str): + prediction_labels = [prediction_labels] + + assert isinstance(target_labels, List) # mypy + assert isinstance(prediction_labels, List) # mypy + # Member variables + self._regularisation_loss: Optional[float] = None + self._target_labels = target_labels + self._prediction_labels = prediction_labels + self._loss_function = loss_function + self._inference = False + self._loss_weight = loss_weight + + self._transform_prediction_training: Callable[ + [Tensor], Tensor + ] = lambda x: x + self._transform_prediction_inference: Callable[ + [Tensor], Tensor + ] = lambda x: x + self._transform_target: Callable[[Tensor], Tensor] = lambda x: x + self._validate_and_set_transforms( + transform_prediction_and_target, + transform_target, + transform_inference, + transform_support, + ) + + # Mapping from last hidden layer to required size of input + self._affine = Linear(hidden_size, self.nb_inputs) + +
+[docs] + @final + def forward(self, x: Union[Tensor, Data]) -> Union[Tensor, Data]: + """Forward pass.""" + self._regularisation_loss = 0 # Reset + x = self._affine(x) + x = self._forward(x) + return self._transform_prediction(x)
+ + + @final + def _transform_prediction( + self, prediction: Union[Tensor, Data] + ) -> Union[Tensor, Data]: + if self._inference: + return self._transform_prediction_inference(prediction) + else: + return self._transform_prediction_training(prediction) + + @abstractmethod + def _forward(self, x: Union[Tensor, Data]) -> Union[Tensor, Data]: + """Syntax like `.forward`, for implentation in inheriting classes.""" + +
+[docs] + @final + def compute_loss(self, pred: Union[Tensor, Data], data: Data) -> Tensor: + """Compute loss of `pred` wrt. + + target labels in `data`. + """ + target = torch.stack( + [data[label] for label in self._target_labels], dim=1 + ) + target = self._transform_target(target) + if self._loss_weight is not None: + weights = data[self._loss_weight] + else: + weights = None + loss = ( + self._loss_function(pred, target, weights=weights) + + self._regularisation_loss + ) + return loss
+ + +
+[docs] + @final + def inference(self) -> None: + """Activate inference mode.""" + self._inference = True
+ + +
+[docs] + @final + def train_eval(self) -> None: + """Deactivate inference mode.""" + self._inference = False
+ + + @final + def _validate_and_set_transforms( + self, + transform_prediction_and_target: Union[Callable, None], + transform_target: Union[Callable, None], + transform_inference: Union[Callable, None], + transform_support: Union[Tuple, None], + ) -> None: + """Validate and set transforms. + + Assert that a valid combination of transformation arguments are passed + and update the corresponding functions. + """ + # Checks + assert not ( + (transform_prediction_and_target is not None) + and (transform_target is not None) + ), "Please specify at most one of `transform_prediction_and_target` and `transform_target`" + if (transform_target is not None) != (transform_inference is not None): + self.warning( + "Setting one of `transform_target` and `transform_inference`, but not " + "the other." + ) + + if transform_target is not None: + assert transform_target is not None + assert transform_inference is not None + + if transform_support is not None: + assert transform_support is not None + + assert ( + len(transform_support) == 2 + ), "Please specify min and max for transformation support." + x_test = torch.from_numpy( + np.linspace(transform_support[0], transform_support[1], 10) + ) + else: + x_test = np.logspace(-6, 6, 12 + 1) + x_test = torch.from_numpy( + np.concatenate([-x_test[::-1], [0], x_test]) + ) + + # Add feature dimension before inference transformation to make it + # match the dimensions of a standard prediction. Remove it again + # before comparison. Temporary + try: + t_test = torch.unsqueeze(transform_target(x_test), -1) + t_test = torch.squeeze(transform_inference(t_test), -1) + valid = torch.isfinite(t_test) + + assert torch.allclose(t_test[valid], x_test[valid]), ( + "The provided transforms for targets during training and " + "predictions during inference are not inverse. Please " + "adjust transformation functions or support." + ) + del x_test, t_test, valid + + except IndexError: + self.warning( + "transform_target and/or transform_inference rely on " + "indexing, which we won't validate. Please make sure that " + "they are mutually inverse, i.e. that\n" + " x = transform_inference(transform_target(x))\n" + "for all x that are within your target range." + ) + + # Set transforms + if transform_prediction_and_target is not None: + self._transform_prediction_training = ( + transform_prediction_and_target + ) + self._transform_target = transform_prediction_and_target + else: + if transform_target is not None: + self._transform_target = transform_target + if transform_inference is not None: + self._transform_prediction_inference = transform_inference
+ + + +
+[docs] +class IdentityTask(Task): + """Identity, or trivial, task.""" + + @save_model_config + def __init__( + self, + nb_outputs: int, + target_labels: Union[List[str], Any], + *args: Any, + **kwargs: Any, + ): + """Construct IdentityTask. + + Return the `nb_outputs` as a direct, affine transformation of the last + hidden layer. + """ + self._nb_inputs = nb_outputs + self._default_target_labels = ( + target_labels + if isinstance(target_labels, list) + else [target_labels] + ) + self._default_prediction_labels = [ + f"target_{i}_pred" for i in range(len(self._default_target_labels)) + ] + + super().__init__(*args, **kwargs) + # Base class constructor + + @property + def default_target_labels(self) -> List[str]: + """Return default target labels.""" + return self._default_target_labels + + @property + def default_prediction_labels(self) -> List[str]: + """Return default prediction labels.""" + return self._default_prediction_labels + + @property + def nb_inputs(self) -> int: + """Return number of inputs assumed by task.""" + return self._nb_inputs + + def _forward(self, x: Tensor) -> Tensor: + # Leave it as is. + return x
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/models/utils.html b/_modules/graphnet/models/utils.html new file mode 100644 index 000000000..35c1b3460 --- /dev/null +++ b/_modules/graphnet/models/utils.html @@ -0,0 +1,430 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.models.utils — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.models.utils

+"""Utility functions for `graphnet.models`."""
+
+from typing import List, Tuple, Union
+from torch_geometric.nn import knn_graph
+from torch_geometric.data import Batch
+import torch
+from torch import Tensor, LongTensor
+
+from torch_geometric.utils.homophily import homophily
+
+
+
+[docs] +def calculate_xyzt_homophily( + x: Tensor, edge_index: LongTensor, batch: Batch +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Calculate xyzt-homophily from a batch of graphs. + + Homophily is a graph scalar quantity that measures the likeness of + variables in nodes. Notice that this calculator assumes a special order of + input features in x. + + Returns: + Tuple, each element with shape [batch_size,1]. + """ + hx = homophily(edge_index, x[:, 0], batch).reshape(-1, 1) + hy = homophily(edge_index, x[:, 1], batch).reshape(-1, 1) + hz = homophily(edge_index, x[:, 2], batch).reshape(-1, 1) + ht = homophily(edge_index, x[:, 3], batch).reshape(-1, 1) + return hx, hy, hz, ht
+ + + +
+[docs] +def calculate_distance_matrix(xyz_coords: Tensor) -> Tensor: + """Calculate the matrix of pairwise distances between pulses. + + Args: + xyz_coords: (x,y,z)-coordinates of pulses, of shape [nb_doms, 3]. + + Returns: + Matrix of pairwise distances, of shape [nb_doms, nb_doms] + """ + diff = xyz_coords.unsqueeze(dim=2) - xyz_coords.T.unsqueeze(dim=0) + return torch.sqrt(torch.sum(diff**2, dim=1))
+ + + +
+[docs] +def knn_graph_batch(batch: Batch, k: List[int], columns: List[int]) -> Batch: + """Calculate k-nearest-neighbours with individual k for each batch event. + + Args: + batch: Batch of events. + k: A list of k's. + columns: The columns of Data.x used for computing the distances. E.g., + Data.x[:,[0,1,2]] + + Returns: + Returns the same batch of events, but with updated edges. + """ + data_list = batch.to_data_list() + for i in range(len(data_list)): + data_list[i].edge_index = knn_graph( + x=data_list[i].x[:, columns], k=k[i] + ) + return Batch.from_data_list(data_list)
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/pisa/fitting.html b/_modules/graphnet/pisa/fitting.html index 8be91cc6e..1865353e4 100644 --- a/_modules/graphnet/pisa/fitting.html +++ b/_modules/graphnet/pisa/fitting.html @@ -1169,7 +1169,7 @@

Source code for graphnet.pisa. Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/pisa/plotting.html b/_modules/graphnet/pisa/plotting.html index c1cbdd4a9..6f31c1d90 100644 --- a/_modules/graphnet/pisa/plotting.html +++ b/_modules/graphnet/pisa/plotting.html @@ -528,7 +528,7 @@

Source code for graphnet.pisa Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/training/callbacks.html b/_modules/graphnet/training/callbacks.html new file mode 100644 index 000000000..fcf981d92 --- /dev/null +++ b/_modules/graphnet/training/callbacks.html @@ -0,0 +1,544 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.training.callbacks — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.training.callbacks

+"""Callback class(es) for using during model training."""
+
+import logging
+from typing import Dict, List
+import warnings
+
+import numpy as np
+from tqdm.std import Bar
+
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import TQDMProgressBar
+from pytorch_lightning.utilities import rank_zero_only
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+from graphnet.utilities.logging import Logger
+
+
+
+[docs] +class PiecewiseLinearLR(_LRScheduler): + """Interpolate learning rate linearly between milestones.""" + + def __init__( + self, + optimizer: Optimizer, + milestones: List[int], + factors: List[float], + last_epoch: int = -1, + verbose: bool = False, + ): + """Construct `PiecewiseLinearLR`. + + For each milestone, denoting a specified number of steps, a factor + multiplying the base learning rate is specified. For steps between two + milestones, the learning rate is interpolated linearly between the two + closest milestones. For steps before the first milestone, the factor + for the first milestone is used; vice versa for steps after the last + milestone. + + Args: + optimizer: Wrapped optimizer. + milestones: List of step indices. Must be increasing. + factors: List of multiplicative factors. Must be same length as + `milestones`. + last_epoch: The index of the last epoch. + verbose: If ``True``, prints a message to stdout for each update. + """ + # Check(s) + if milestones != sorted(milestones): + raise ValueError("Milestones must be increasing") + if len(milestones) != len(factors): + raise ValueError( + "Only multiplicative factor must be specified for each milestone." + ) + + self.milestones = milestones + self.factors = factors + super().__init__(optimizer, last_epoch, verbose) + + def _get_factor(self) -> np.ndarray: + # Linearly interpolate multiplicative factor between milestones. + return np.interp(self.last_epoch, self.milestones, self.factors) + +
+[docs] + def get_lr(self) -> List[float]: + """Get effective learning rate(s) for each optimizer.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + + return [base_lr * self._get_factor() for base_lr in self.base_lrs]
+
+ + + +
+[docs] +class ProgressBar(TQDMProgressBar): + """Custom progress bar for graphnet. + + Customises the default progress in pytorch-lightning. + """ + + def _common_config(self, bar: Bar) -> Bar: + bar.unit = " batch(es)" + bar.colour = "green" + return bar + +
+[docs] + def init_validation_tqdm(self) -> Bar: + """Override for customisation.""" + bar = super().init_validation_tqdm() + bar = self._common_config(bar) + return bar
+ + +
+[docs] + def init_predict_tqdm(self) -> Bar: + """Override for customisation.""" + bar = super().init_predict_tqdm() + bar = self._common_config(bar) + return bar
+ + +
+[docs] + def init_test_tqdm(self) -> Bar: + """Override for customisation.""" + bar = super().init_test_tqdm() + bar = self._common_config(bar) + return bar
+ + +
+[docs] + def init_train_tqdm(self) -> Bar: + """Override for customisation.""" + bar = super().init_train_tqdm() + bar = self._common_config(bar) + return bar
+ + +
+[docs] + def get_metrics(self, trainer: Trainer, model: LightningModule) -> Dict: + """Override to not show the version number in the logging.""" + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items
+ + +
+[docs] + def on_train_epoch_start( + self, trainer: Trainer, model: LightningModule + ) -> None: + """Print the results of the previous epoch on a separate line. + + This allows the user to see the losses/metrics for previous epochs + while the current is training. The default behaviour in pytorch- + lightning is to overwrite the progress bar from previous epochs. + """ + if trainer.current_epoch > 0: + self.train_progress_bar.set_postfix( + self.get_metrics(trainer, model) + ) + print("") + super().on_train_epoch_start(trainer, model) + self.train_progress_bar.set_description( + f"Epoch {trainer.current_epoch:2d}" + )
+ + +
+[docs] + def on_train_epoch_end( + self, trainer: Trainer, model: LightningModule + ) -> None: + """Log the final progress bar for the epoch to file. + + Don't duplciate to stdout. + """ + super().on_train_epoch_end(trainer, model) + + if rank_zero_only.rank == 0: + # Construct Logger + logger = Logger() + + # Log only to file, not stream + h = logger.handlers[0] + assert isinstance(h, logging.StreamHandler) + level = h.level + h.setLevel(logging.ERROR) + logger.info(str(super().train_progress_bar)) + h.setLevel(level)
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/training/labels.html b/_modules/graphnet/training/labels.html new file mode 100644 index 000000000..b2dd18a77 --- /dev/null +++ b/_modules/graphnet/training/labels.html @@ -0,0 +1,436 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.training.labels — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.training.labels

+"""Class(es) for constructing training labels at runtime."""
+
+from abc import ABC, abstractmethod
+import torch
+from torch_geometric.data import Data
+from graphnet.utilities.logging import Logger
+
+
+
+[docs] +class Label(ABC, Logger): + """Base `Label` class for producing labels from single `Data` instance.""" + + def __init__(self, key: str): + """Construct `Label`. + + Args: + key: The name of the field in `Data` where the label will be + stored. That is, `graph[key] = label`. + """ + self._key = key + + # Base class constructor + super().__init__(name=__name__, class_name=self.__class__.__name__) + + @property + def key(self) -> str: + """Return value of `key`.""" + return self._key + + @abstractmethod + def __call__(self, graph: Data) -> torch.tensor: + """Label-specific implementation."""
+ + + +
+[docs] +class Direction(Label): + """Class for producing particle direction/pointing label.""" + + def __init__( + self, + key: str = "direction", + azimuth_key: str = "azimuth", + zenith_key: str = "zenith", + ): + """Construct `Direction`. + + Args: + key: The name of the field in `Data` where the label will be + stored. That is, `graph[key] = label`. + azimuth_key: The name of the pre-existing key in `graph` that will + be used to access the azimiuth angle, used when calculating + the direction. + zenith_key: The name of the pre-existing key in `graph` that will + be used to access the zenith angle, used when calculating the + direction. + """ + self._azimuth_key = azimuth_key + self._zenith_key = zenith_key + + # Base class constructor + super().__init__(key=key) + + def __call__(self, graph: Data) -> torch.tensor: + """Compute label for `graph`.""" + x = torch.cos(graph[self._azimuth_key]) * torch.sin( + graph[self._zenith_key] + ).reshape(-1, 1) + y = torch.sin(graph[self._azimuth_key]) * torch.sin( + graph[self._zenith_key] + ).reshape(-1, 1) + z = torch.cos(graph[self._zenith_key]).reshape(-1, 1) + return torch.cat((x, y, z), dim=1)
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/training/loss_functions.html b/_modules/graphnet/training/loss_functions.html new file mode 100644 index 000000000..dd70fd272 --- /dev/null +++ b/_modules/graphnet/training/loss_functions.html @@ -0,0 +1,859 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.training.loss_functions — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.training.loss_functions

+"""Collection of loss functions.
+
+All loss functions inherit from `LossFunction` which ensures a common syntax,
+handles per-event weights, etc.
+"""
+
+from abc import abstractmethod
+from typing import Any, Optional, Union, List, Dict
+
+import numpy as np
+import scipy.special
+import torch
+from torch import Tensor
+from torch import nn
+from torch.nn.functional import (
+    one_hot,
+    cross_entropy,
+    binary_cross_entropy,
+    softplus,
+)
+
+from graphnet.utilities.config import save_model_config
+from graphnet.models.model import Model
+from graphnet.utilities.decorators import final
+
+
+
+[docs] +class LossFunction(Model): + """Base class for loss functions in `graphnet`.""" + + @save_model_config + def __init__(self, **kwargs: Any) -> None: + """Construct `LossFunction`, saving model config.""" + super().__init__(**kwargs) + +
+[docs] + @final + def forward( # type: ignore[override] + self, + prediction: Tensor, + target: Tensor, + weights: Optional[Tensor] = None, + return_elements: bool = False, + ) -> Tensor: + """Forward pass for all loss functions. + + Args: + prediction: Tensor containing predictions. Shape [N,P] + target: Tensor containing targets. Shape [N,T] + return_elements: Whether elementwise loss terms should be returned. + The alternative is to return the averaged loss across examples. + + Returns: + Loss, either averaged to a scalar (if `return_elements = False`) or + elementwise terms with shape [N,] (if `return_elements = True`). + """ + elements = self._forward(prediction, target) + if weights is not None: + elements = elements * weights + assert elements.size(dim=0) == target.size( + dim=0 + ), "`_forward` should return elementwise loss terms." + + return elements if return_elements else torch.mean(elements)
+ + + @abstractmethod + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Syntax like `.forward`, for implentation in inheriting classes."""
+ + + +
+[docs] +class MSELoss(LossFunction): + """Mean squared error loss.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Implement loss calculation.""" + # Check(s) + assert prediction.dim() == 2 + assert prediction.size() == target.size() + + elements = torch.mean((prediction - target) ** 2, dim=-1) + return elements
+ + + +
+[docs] +class RMSELoss(MSELoss): + """Root mean squared error loss.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Implement loss calculation.""" + # Check(s) + elements = super()._forward(prediction, target) + elements = torch.sqrt(elements) + return elements
+ + + +
+[docs] +class LogCoshLoss(LossFunction): + """Log-cosh loss function. + + Acts like x^2 for small x; and like |x| for large x. + """ + + @classmethod + def _log_cosh(cls, x: Tensor) -> Tensor: # pylint: disable=invalid-name + """Numerically stable version on log(cosh(x)). + + Used to avoid `inf` for even moderately large differences. + See [https://github.com/keras-team/keras/blob/v2.6.0/keras/losses.py#L1580-L1617] + """ + return x + softplus(-2.0 * x) - np.log(2.0) + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Implement loss calculation.""" + diff = prediction - target + elements = self._log_cosh(diff) + return elements
+ + + +
+[docs] +class CrossEntropyLoss(LossFunction): + """Compute cross-entropy loss for classification tasks. + + Predictions are an [N, num_class]-matrix of logits (i.e., non-softmax'ed + probabilities), and targets are an [N,1]-matrix with integer values in + (0, num_classes - 1). + """ + + @save_model_config + def __init__( + self, + options: Union[int, List[Any], Dict[Any, int]], + *args: Any, + **kwargs: Any, + ): + """Construct CrossEntropyLoss.""" + # Base class constructor + super().__init__(*args, **kwargs) + + # Member variables + self._options = options + self._nb_classes: int + if isinstance(self._options, int): + assert self._options in [torch.int32, torch.int64] + assert ( + self._options >= 2 + ), f"Minimum of two classes required. Got {self._options}." + self._nb_classes = options # type: ignore + elif isinstance(self._options, list): + self._nb_classes = len(self._options) # type: ignore + elif isinstance(self._options, dict): + self._nb_classes = len( + np.unique(list(self._options.values())) + ) # type: ignore + else: + raise ValueError( + f"Class options of type {type(self._options)} not supported" + ) + + self._loss = nn.CrossEntropyLoss(reduction="none") + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Transform outputs to angle and prepare prediction.""" + if isinstance(self._options, int): + # Integer number of classes: Targets are expected to be in + # (0, nb_classes - 1). + + # Target integers are positive + assert torch.all(target >= 0) + + # Target integers are consistent with the expected number of class. + assert torch.all(target < self._options) + + assert target.dtype in [torch.int32, torch.int64] + target_integer = target + + elif isinstance(self._options, list): + # List of classes: Mapping target classes in list onto + # (0, nb_classes - 1). Example: + # Given options: [1, 12, 13, ...] + # Yields: [1, 13, 12] -> [0, 2, 1, ...] + target_integer = torch.tensor( + [self._options.index(value) for value in target] + ) + + elif isinstance(self._options, dict): + # Dictionary of classes: Mapping target classes in dict onto + # (0, nb_classes - 1). Example: + # Given options: {1: 0, -1: 0, 12: 1, -12: 1, ...} + # Yields: [1, -1, -12, ...] -> [0, 0, 1, ...] + target_integer = torch.tensor( + [self._options[int(value)] for value in target] + ) + + else: + assert False, "Shouldn't reach here." + + target_one_hot: Tensor = one_hot(target_integer, self._nb_classes).to( + prediction.device + ) + + return self._loss(prediction.float(), target_one_hot.float())
+ + + +
+[docs] +class BinaryCrossEntropyLoss(LossFunction): + """Compute binary cross entropy loss. + + Predictions are vector probabilities (i.e., values between 0 and 1), and + targets should be 0 and 1. + """ + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + return binary_cross_entropy( + prediction.float(), target.float(), reduction="none" + )
+ + + +
+[docs] +class LogCMK(torch.autograd.Function): + """MIT License. + + Copyright (c) 2019 Max Ryabinin + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + _____________________ + + From [https://github.com/mryab/vmf_loss/blob/master/losses.py] + Modified to use modified Bessel function instead of exponentially scaled ditto + (i.e. `.ive` -> `.iv`) as indiciated in [1812.04616] in spite of suggestion in + Sec. 8.2 of this paper. The change has been validated through comparison with + exact calculations for `m=2` and `m=3` and found to yield the correct results. + """ + +
+[docs] + @staticmethod + def forward( + ctx: Any, m: int, kappa: Tensor + ) -> Tensor: # pylint: disable=invalid-name,arguments-differ + """Forward pass.""" + dtype = kappa.dtype + ctx.save_for_backward(kappa) + ctx.m = m + ctx.dtype = dtype + kappa = kappa.double() + iv = torch.from_numpy( + scipy.special.iv(m / 2.0 - 1, kappa.cpu().numpy()) + ).to(kappa.device) + return ( + (m / 2.0 - 1) * torch.log(kappa) + - torch.log(iv) + - (m / 2) * np.log(2 * np.pi) + ).type(dtype)
+ + +
+[docs] + @staticmethod + def backward( + ctx: Any, grad_output: Tensor + ) -> Tensor: # pylint: disable=invalid-name,arguments-differ + """Backward pass.""" + kappa = ctx.saved_tensors[0] + m = ctx.m + dtype = ctx.dtype + kappa = kappa.double().cpu().numpy() + grads = -( + (scipy.special.iv(m / 2.0, kappa)) + / (scipy.special.iv(m / 2.0 - 1, kappa)) + ) + return ( + None, + grad_output + * torch.from_numpy(grads).to(grad_output.device).type(dtype), + )
+
+ + + +
+[docs] +class VonMisesFisherLoss(LossFunction): + """General class for calculating von Mises-Fisher loss. + + Requires implementation for specific dimension `m` in which the target and + prediction vectors need to be prepared. + """ + +
+[docs] + @classmethod + def log_cmk_exact( + cls, m: int, kappa: Tensor + ) -> Tensor: # pylint: disable=invalid-name + """Calculate $log C_{m}(k)$ term in von Mises-Fisher loss exactly.""" + return LogCMK.apply(m, kappa)
+ + +
+[docs] + @classmethod + def log_cmk_approx( + cls, m: int, kappa: Tensor + ) -> Tensor: # pylint: disable=invalid-name + """Calculate $log C_{m}(k)$ term in von Mises-Fisher loss approx. + + [https://arxiv.org/abs/1812.04616] Sec. 8.2 with additional minus sign. + """ + v = m / 2.0 - 0.5 + a = torch.sqrt((v + 1) ** 2 + kappa**2) + b = v - 1 + return -a + b * torch.log(b + a)
+ + +
+[docs] + @classmethod + def log_cmk( + cls, m: int, kappa: Tensor, kappa_switch: float = 100.0 + ) -> Tensor: # pylint: disable=invalid-name + """Calculate $log C_{m}(k)$ term in von Mises-Fisher loss. + + Since `log_cmk_exact` is diverges for `kappa` >~ 700 (using float64 + precision), and since `log_cmk_approx` is unaccurate for small `kappa`, + this method automatically switches between the two at `kappa_switch`, + ensuring continuity at this point. + """ + kappa_switch = torch.tensor([kappa_switch]).to(kappa.device) + mask_exact = kappa < kappa_switch + + # Ensure continuity at `kappa_switch` + offset = cls.log_cmk_approx(m, kappa_switch) - cls.log_cmk_exact( + m, kappa_switch + ) + ret = cls.log_cmk_approx(m, kappa) - offset + ret[mask_exact] = cls.log_cmk_exact(m, kappa[mask_exact]) + return ret
+ + + def _evaluate(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate von Mises-Fisher loss for a vector in D dimensons. + + This loss utilises the von Mises-Fisher distribution, which is a + probability distribution on the (D - 1) sphere in D-dimensional space. + + Args: + prediction: Predicted vector, of shape [batch_size, D]. + target: Target unit vector, of shape [batch_size, D]. + + Returns: + Elementwise von Mises-Fisher loss terms. + """ + # Check(s) + assert prediction.dim() == 2 + assert target.dim() == 2 + assert prediction.size() == target.size() + + # Computing loss + m = target.size()[1] + k = torch.norm(prediction, dim=1) + dotprod = torch.sum(prediction * target, dim=1) + elements = -self.log_cmk(m, k) - dotprod + return elements + + @abstractmethod + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + raise NotImplementedError
+ + + +
+[docs] +class VonMisesFisher2DLoss(VonMisesFisherLoss): + """von Mises-Fisher loss function vectors in the 2D plane.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate von Mises-Fisher loss for an angle in the 2D plane. + + Args: + prediction: Output of the model. Must have shape [N, 2] where 0th + column is a prediction of `angle` and 1st column is an estimate + of `kappa`. + target: Target tensor, extracted from graph object. + + Returns: + loss: Elementwise von Mises-Fisher loss terms. Shape [N,] + """ + # Check(s) + assert prediction.dim() == 2 and prediction.size()[1] == 2 + assert target.dim() == 2 + assert prediction.size()[0] == target.size()[0] + + # Formatting target + angle_true = target[:, 0] + t = torch.stack( + [ + torch.cos(angle_true), + torch.sin(angle_true), + ], + dim=1, + ) + + # Formatting prediction + angle_pred = prediction[:, 0] + kappa = prediction[:, 1] + p = kappa.unsqueeze(1) * torch.stack( + [ + torch.cos(angle_pred), + torch.sin(angle_pred), + ], + dim=1, + ) + + return self._evaluate(p, t)
+ + + +
+[docs] +class EuclideanDistanceLoss(LossFunction): + """Mean squared error in three dimensions.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate 3D Euclidean distance between predicted and target. + + Args: + prediction: Output of the model. Must have shape [N, 3] + target: Target tensor, extracted from graph object. + + Returns: + Elementwise von Mises-Fisher loss terms. Shape [N,] + """ + return torch.sqrt( + (prediction[:, 0] - target[:, 0]) ** 2 + + (prediction[:, 1] - target[:, 1]) ** 2 + + (prediction[:, 2] - target[:, 2]) ** 2 + )
+ + + +
+[docs] +class VonMisesFisher3DLoss(VonMisesFisherLoss): + """von Mises-Fisher loss function vectors in the 3D plane.""" + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate von Mises-Fisher loss for a direction in the 3D. + + Args: + prediction: Output of the model. Must have shape [N, 4] where + columns 0, 1, 2 are predictions of `direction` and last column + is an estimate of `kappa`. + target: Target tensor, extracted from graph object. + + Returns: + Elementwise von Mises-Fisher loss terms. Shape [N,] + """ + target = target.reshape(-1, 3) + # Check(s) + assert prediction.dim() == 2 and prediction.size()[1] == 4 + assert target.dim() == 2 + assert prediction.size()[0] == target.size()[0] + + kappa = prediction[:, 3] + p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] + return self._evaluate(p, target)
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/training/utils.html b/_modules/graphnet/training/utils.html new file mode 100644 index 000000000..eff87f698 --- /dev/null +++ b/_modules/graphnet/training/utils.html @@ -0,0 +1,656 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.training.utils — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.training.utils

+"""Utility functions for `graphnet.training`."""
+
+from collections import OrderedDict
+import os
+from typing import Dict, List, Optional, Tuple, Union, Callable
+
+import numpy as np
+import pandas as pd
+from pytorch_lightning import Trainer
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+from torch_geometric.data import Batch, Data
+
+from graphnet.data.dataset import Dataset
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
+from graphnet.models import Model
+from graphnet.utilities.logging import Logger
+from graphnet.models.graphs import GraphDefinition
+
+
+
+[docs] +def collate_fn(graphs: List[Data]) -> Batch: + """Remove graphs with less than two DOM hits. + + Should not occur in "production. + """ + graphs = [g for g in graphs if g.n_pulses > 1] + return Batch.from_data_list(graphs)
+ + + +# @TODO: Remove in favour of DataLoader{,.from_dataset_config} +
+[docs] +def make_dataloader( + db: str, + pulsemaps: Union[str, List[str]], + graph_definition: Optional[GraphDefinition], + features: List[str], + truth: List[str], + *, + batch_size: int, + shuffle: bool, + selection: Optional[List[int]] = None, + num_workers: int = 10, + persistent_workers: bool = True, + node_truth: List[str] = None, + truth_table: str = "truth", + node_truth_table: Optional[str] = None, + string_selection: List[int] = None, + loss_weight_table: Optional[str] = None, + loss_weight_column: Optional[str] = None, + index_column: str = "event_no", + labels: Optional[Dict[str, Callable]] = None, +) -> DataLoader: + """Construct `DataLoader` instance.""" + # Check(s) + if isinstance(pulsemaps, str): + pulsemaps = [pulsemaps] + + dataset = SQLiteDataset( + path=db, + pulsemaps=pulsemaps, + features=features, + truth=truth, + selection=selection, + node_truth=node_truth, + truth_table=truth_table, + node_truth_table=node_truth_table, + string_selection=string_selection, + loss_weight_table=loss_weight_table, + loss_weight_column=loss_weight_column, + index_column=index_column, + graph_definition=graph_definition, + ) + + # adds custom labels to dataset + if isinstance(labels, dict): + for label in labels.keys(): + dataset.add_label(key=label, fn=labels[label]) + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + persistent_workers=persistent_workers, + prefetch_factor=2, + ) + + return dataloader
+ + + +# @TODO: Remove in favour of DataLoader{,.from_dataset_config} +
+[docs] +def make_train_validation_dataloader( + db: str, + graph_definition: Optional[GraphDefinition], + selection: Optional[List[int]], + pulsemaps: Union[str, List[str]], + features: List[str], + truth: List[str], + *, + batch_size: int, + database_indices: Optional[List[int]] = None, + seed: int = 42, + test_size: float = 0.33, + num_workers: int = 10, + persistent_workers: bool = True, + node_truth: Optional[str] = None, + truth_table: str = "truth", + node_truth_table: Optional[str] = None, + string_selection: Optional[List[int]] = None, + loss_weight_column: Optional[str] = None, + loss_weight_table: Optional[str] = None, + index_column: str = "event_no", + labels: Optional[Dict[str, Callable]] = None, +) -> Tuple[DataLoader, DataLoader]: + """Construct train and test `DataLoader` instances.""" + # Reproducibility + rng = np.random.default_rng(seed=seed) + # Checks(s) + if isinstance(pulsemaps, str): + pulsemaps = [pulsemaps] + + if selection is None: + # If no selection is provided, use all events in dataset. + dataset: Dataset + if db.endswith(".db"): + dataset = SQLiteDataset( + path=db, + graph_definition=graph_definition, + pulsemaps=pulsemaps, + features=features, + truth=truth, + truth_table=truth_table, + index_column=index_column, + ) + elif db.endswith(".parquet"): + dataset = ParquetDataset( + path=db, + graph_definition=graph_definition, + pulsemaps=pulsemaps, + features=features, + truth=truth, + truth_table=truth_table, + index_column=index_column, + ) + else: + raise RuntimeError( + f"File {db} with format {db.split('.'[-1])} not supported." + ) + selection = dataset._get_all_indices() + + # Perform train/validation split + if isinstance(db, list): + df_for_shuffle = pd.DataFrame( + {"event_no": selection, "db": database_indices} + ) + shuffled_df = df_for_shuffle.sample( + frac=1, replace=False, random_state=rng + ) + training_df, validation_df = train_test_split( + shuffled_df, test_size=test_size, random_state=seed + ) + training_selection = training_df.values.tolist() + validation_selection = validation_df.values.tolist() + else: + training_selection, validation_selection = train_test_split( + selection, test_size=test_size, random_state=seed + ) + + # Create DataLoaders + common_kwargs = dict( + db=db, + pulsemaps=pulsemaps, + features=features, + truth=truth, + batch_size=batch_size, + num_workers=num_workers, + persistent_workers=persistent_workers, + node_truth=node_truth, + truth_table=truth_table, + node_truth_table=node_truth_table, + string_selection=string_selection, + loss_weight_column=loss_weight_column, + loss_weight_table=loss_weight_table, + index_column=index_column, + labels=labels, + graph_definition=graph_definition, + ) + + training_dataloader = make_dataloader( + shuffle=True, + selection=training_selection, + **common_kwargs, # type: ignore[arg-type] + ) + + validation_dataloader = make_dataloader( + shuffle=False, + selection=validation_selection, + **common_kwargs, # type: ignore[arg-type] + ) + + return ( + training_dataloader, + validation_dataloader, + )
+ + + +# @TODO: Remove in favour of Model.predict{,_as_dataframe} +
+[docs] +def get_predictions( + trainer: Trainer, + model: Model, + dataloader: DataLoader, + prediction_columns: List[str], + *, + node_level: bool = False, + additional_attributes: Optional[List[str]] = None, +) -> pd.DataFrame: + """Get `model` predictions on `dataloader`.""" + # Gets predictions from model on the events in the dataloader. + # NOTE: dataloader must NOT have shuffle = True! + + # Check(s) + if additional_attributes is None: + additional_attributes = [] + assert isinstance(additional_attributes, list) + + # Set model to inference mode + model.inference() + + # Get predictions + predictions_torch = trainer.predict(model, dataloader) + predictions_list = [ + p[0].detach().cpu().numpy() for p in predictions_torch + ] # Assuming single task + predictions = np.concatenate(predictions_list, axis=0) + try: + assert len(prediction_columns) == predictions.shape[1] + except IndexError: + predictions = predictions.reshape((-1, 1)) + assert len(prediction_columns) == predictions.shape[1] + + # Get additional attributes + attributes: Dict[str, List[np.ndarray]] = OrderedDict( + [(attr, []) for attr in additional_attributes] + ) + for batch in dataloader: + for attr in attributes: + attribute = batch[attr].detach().cpu().numpy() + if node_level: + if attr == "event_no": + attribute = np.repeat( + attribute, batch["n_pulses"].detach().cpu().numpy() + ) + attributes[attr].extend(attribute) + + data = np.concatenate( + [predictions] + + [ + np.asarray(values)[:, np.newaxis] for values in attributes.values() + ], + axis=1, + ) + + results = pd.DataFrame( + data, columns=prediction_columns + additional_attributes + ) + return results
+ + + +# @TODO: Remove +
+[docs] +def save_results( + db: str, tag: str, results: pd.DataFrame, archive: str, model: Model +) -> None: + """Save trained model and prediction `results` in `db`.""" + db_name = db.split("/")[-1].split(".")[0] + path = archive + "/" + db_name + "/" + tag + os.makedirs(path, exist_ok=True) + results.to_csv(path + "/results.csv") + model.save_state_dict(path + "/" + tag + "_state_dict.pth") + model.save(path + "/" + tag + "_model.pth") + Logger().info("Results saved at: \n %s" % path)
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/training/weight_fitting.html b/_modules/graphnet/training/weight_fitting.html index 9811e69cf..909eb3dbe 100644 --- a/_modules/graphnet/training/weight_fitting.html +++ b/_modules/graphnet/training/weight_fitting.html @@ -556,7 +556,7 @@

Source code for gra Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/utilities/argparse.html b/_modules/graphnet/utilities/argparse.html index 43877f0e5..a68f077b2 100644 --- a/_modules/graphnet/utilities/argparse.html +++ b/_modules/graphnet/utilities/argparse.html @@ -515,7 +515,7 @@

Source code for graphnet Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/utilities/config/base_config.html b/_modules/graphnet/utilities/config/base_config.html new file mode 100644 index 000000000..f57c1f185 --- /dev/null +++ b/_modules/graphnet/utilities/config/base_config.html @@ -0,0 +1,449 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.utilities.config.base_config — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.utilities.config.base_config

+"""Base config class(es)."""
+
+from abc import abstractmethod
+from collections import OrderedDict
+import inspect
+import sys
+from typing import Any, Callable, Dict, Optional
+
+from pydantic import BaseModel
+import ruamel.yaml as yaml
+
+
+CONFIG_FILES_SUFFIXES = (".yml", ".yaml")
+
+
+
+[docs] +class BaseConfig(BaseModel): + """Base class for Configs.""" + +
+[docs] + @classmethod + def load(cls, path: str) -> "BaseConfig": + """Load BaseConfig from `path`.""" + assert path.endswith( + CONFIG_FILES_SUFFIXES + ), "Please specify YAML config file." + with open(path, "r") as f: + yaml_ = yaml.YAML(typ="safe", pure=True) + config_dict = yaml_.load(f) + + return cls(**config_dict)
+ + +
+[docs] + def dump(self, path: Optional[str] = None) -> Optional[str]: + """Save BaseConfig to `path` as YAML file, or return as string.""" + config_dict = self.as_dict()[self.__class__.__name__] + yaml_ = yaml.YAML(typ="safe", pure=True) + if path: + if not path.endswith(CONFIG_FILES_SUFFIXES): + path += CONFIG_FILES_SUFFIXES[0] + with open(path, "w") as f: + yaml_.dump(config_dict, f) + return None + else: + return yaml_.dump(config_dict, sys.stdout)
+ + +
+[docs] + def as_dict(self) -> Dict[str, Dict[str, Any]]: + """Represent BaseConfig as a dict. + + This builds on `BaseModel.dict()` but can be overwritten. + """ + return {self.__class__.__name__: self.dict()}
+
+ + + +
+[docs] +def get_all_argument_values( + fn: Callable, *args: Any, **kwargs: Any +) -> Dict[str, Any]: + """Return dict of all argument values to `fn`, including defaults.""" + # Get all default argument values + cfg = OrderedDict() + for key, param in inspect.signature(fn).parameters.items(): + # Don't save `self`, `*args`, or `**kwargs` + if key == "self" or param.kind in [ + param.VAR_POSITIONAL, + param.VAR_KEYWORD, + ]: + continue + cfg[key] = param.default + + # Add positional arguments + for key, val in zip(cfg.keys(), args): + cfg[key] = val + + # Add keyword arguments + cfg.update(kwargs) + + return cfg
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/configurable.html b/_modules/graphnet/utilities/config/configurable.html new file mode 100644 index 000000000..cfe023b65 --- /dev/null +++ b/_modules/graphnet/utilities/config/configurable.html @@ -0,0 +1,408 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.utilities.config.configurable — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.utilities.config.configurable

+"""Bases for all configurable classes in  `graphnet`."""
+
+from abc import ABC, abstractclassmethod
+from typing import Any, Union
+
+from graphnet.utilities.config.base_config import BaseConfig
+from graphnet.utilities.decorators import final
+
+
+
+[docs] +class Configurable(ABC): + """Base class for all configurable classes in graphnet.""" + + def __init__(self) -> None: + """Construct `Configurable`.""" + self._config: BaseConfig + + # Base class constructor + super().__init__() + + @final + @property + def config(self) -> BaseConfig: + """Return configuration to re-create the instance.""" + try: + return self._config + except AttributeError: + raise AttributeError( + "Config was not set. " + "Did you wrap the class constructor with `save_config`?" + ) + +
+[docs] + @final + def save_config(self, path: str) -> None: + """Save Config to `path` as YAML file.""" + self.config.dump(path)
+ + +
+[docs] + @abstractclassmethod + def from_config(cls, source: Union[BaseConfig, str]) -> Any: + """Construct instance from `source` configuration."""
+
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/dataset_config.html b/_modules/graphnet/utilities/config/dataset_config.html new file mode 100644 index 000000000..91b338109 --- /dev/null +++ b/_modules/graphnet/utilities/config/dataset_config.html @@ -0,0 +1,585 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.utilities.config.dataset_config — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.utilities.config.dataset_config

+"""Config classes for the `graphnet.data.dataset` module."""
+
+from functools import wraps
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Union,
+)
+
+from graphnet.utilities.config.base_config import (
+    BaseConfig,
+    get_all_argument_values,
+)
+from graphnet.utilities.config.parsing import traverse_and_apply
+from .model_config import ModelConfig
+
+if TYPE_CHECKING:
+    from graphnet.models import Model
+
+
+BACKEND_LOOKUP = {
+    "db": "sqlite",
+    "parquet": "parquet",
+}
+
+
+
+[docs] +class DatasetConfig(BaseConfig): + """Configuration for all `Dataset`s.""" + + # Fields + path: Union[str, List[str]] + pulsemaps: Union[str, List[str]] + features: List[str] + truth: List[str] + node_truth: Optional[List[str]] = None + index_column: str = "event_no" + truth_table: str = "truth" + node_truth_table: Optional[str] = None + string_selection: Optional[List[int]] = None + selection: Optional[ + Union[ + str, + List[str], + List[Union[int, List[int]]], + Dict[str, Union[str, List[str]]], + ] + ] = None + loss_weight_table: Optional[str] = None + loss_weight_column: Optional[str] = None + loss_weight_default_value: Optional[float] = None + seed: Optional[int] = None + graph_definition: Any = None + + def __init__(self, **data: Any) -> None: + """Construct `DataConfig`. + + Can be used for dataset configuration as code, thereby making dataset + construction more transparent and reproducible. + + Examples: + In one session, do: + + >>> dataset = Dataset(...) + >>> dataset.config.dump() + path: (...) + pulsemaps: + - (...) + (...) + >>> dataset.config.dump("dataset.yml") + + In another session, you can then do: + >>> dataset = Dataset.from_config("dataset.yml") + + # Uniquely for `DatasetConfig`, you can also define and load + # multiple datasets + >>> dataset.config.selection = { + "train": "event_no % 2 == 0", + "test": "event_no % 2 == 1", + } + >>> dataset.config.dump("dataset.yml") + >>> datasets: Dict[str, Dataset] = Dataset.from_config( + "dataset.yml" + ) + >>> datasets + { + "train": Dataset(...), + "test": Dataset(...), + } + + # You can also combine multiple selections into a single, named + # dataset + >>> dataset.config.selection = { + "train": [ + "event_no % 2 == 0 & abs(pid) == 12", + "event_no % 2 == 0 & abs(pid) == 14", + "event_no % 2 == 0 & abs(pid) == 16", + ], + (...) + } + >>> dataset.config.dump("dataset.yml") + >>> datasets: Dict[str, EnsembleDataset] = Dataset.from_config( + "dataset.yml" + ) + >>> datasets + { + "train": EnsembleDataset(...), + (...) + } + + # Finally, you can still reference existing selection files in CSV + # or JSON formats: + >>> dataset.config.selection = { + "train": "50000 random events ~ train_selection.csv", + "test": "test_selection.csv", + } + """ + # Single-key dictioaries are unpacked + if isinstance(data["selection"], dict) and len(data["selection"]) == 1: + data["selection"] = next(iter(data["selection"].values())) + + # Base class constructor + super().__init__(**data) + + @property + def _backend(self) -> str: + path: str + if isinstance(self.path, list): + path = self.path[0] + else: + assert isinstance(self.path, str) + path = self.path + suffix = path.split(".")[-1] + try: + return BACKEND_LOOKUP[suffix] + except KeyError: + self.error( + f"Dataset at `path` {self.path} with suffix {suffix} not " + "supported." + ) + raise + + @property + def _dataset_class(self) -> type: + """Return the `Dataset` class implementation for this configuration.""" + from graphnet.data.dataset.sqlite import SQLiteDataset + from graphnet.data.dataset.parquet import ParquetDataset + + dataset_class = { + "sqlite": SQLiteDataset, + "parquet": ParquetDataset, + }[self._backend] + + return dataset_class + +
+[docs] + def as_dict(self) -> Dict[str, Dict[str, Any]]: + """Represent ModelConfig as a dict. + + This builds on `BaseModel.dict()` but wraps the output in a single-key + dictionary to make it unambiguous to identify model arguments that are + themselves models. + """ + config_dict = self.dict() + config_dict = traverse_and_apply( + obj=dict(**config_dict), fn=self._parse_torch + ) + return {self.__class__.__name__: config_dict}
+ + + def _parse_torch(self, obj: Any) -> Any: + import torch + + if isinstance(obj, torch.dtype): + return obj.__str__() + else: + return obj
+ + + +
+[docs] +def save_dataset_config(init_fn: Callable) -> Callable: + """Save the arguments to `__init__` functions as member `DatasetConfig`.""" + + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + import torch + + if isinstance(obj, Model): + return obj.config + + if isinstance(obj, torch.dtype): + return obj.__str__() + + else: + return obj + + @wraps(init_fn) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + """Set `DatasetConfig` after calling `init_fn`.""" + # Call wrapped method + ret = init_fn(self, *args, **kwargs) + + # Get all argument values, including defaults + cfg = get_all_argument_values(init_fn, *args, **kwargs) + + # Handle nested `Model`s, etc. + cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) + # Add `DatasetConfig` as member variables + self._config = DatasetConfig(**cfg) + + return ret + + return wrapper
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/model_config.html b/_modules/graphnet/utilities/config/model_config.html new file mode 100644 index 000000000..e63e73186 --- /dev/null +++ b/_modules/graphnet/utilities/config/model_config.html @@ -0,0 +1,654 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.utilities.config.model_config — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.utilities.config.model_config

+"""Config classes for the `graphnet.models` module."""
+from functools import wraps
+import inspect
+import re
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Union,
+)
+import torch
+
+from graphnet.utilities.config.base_config import (
+    BaseConfig,
+    get_all_argument_values,
+)
+from graphnet.utilities.config.parsing import (
+    traverse_and_apply,
+    get_all_grapnet_classes,
+)
+
+if TYPE_CHECKING:
+    from graphnet.models import Model
+
+
+FUNCTION_DEFINITION_PATTERN = (
+    r"^def (?P<function_name>[a-zA-Z]{1}[a-zA-Z0-9_]+) *\(.*\) *:"
+)
+
+
+
+[docs] +class ModelConfig(BaseConfig): + """Configuration for all `Model`s.""" + + # Fields + class_name: str + arguments: Dict[str, Any] + + def __init__(self, **data: Any) -> None: + """Construct `ModelConfig`. + + Can be used for model configuration as code, thereby making model + construction more transparent and reproducible. Note that this does + *not* save any trainable weights, meaning this is only a configuration + for the model's hyperparameters. Any model instantiated from a + ModelConfig or file will be randomly initialised, and thus should be + trained. + + Examples: + In one session, do: + + >>> model = Model(...) + >>> model.config.dump() + arguments: + - (...): (...) + class_name: Model + >>> model.config.dump("model.yml") + + In another session, you can then do: + >>> model = Model.from_config("model.yml") + """ + # Parse any nested `ModelConfig` arguments + for arg in data["arguments"]: + value = data["arguments"][arg] + if isinstance(value, (tuple, list)): + for ix, elem in enumerate(value): + data["arguments"][arg][ + ix + ] = self._parse_if_model_config_entry(elem) + else: + data["arguments"][arg] = self._parse_if_model_config_entry( + value + ) + # Base class constructor + super().__init__(**data) + + def _is_model_config_entry(self, entry: Dict[str, Any]) -> bool: + """Check whether dictionary entry is a `ModelConfig`.""" + return ( + isinstance(entry, dict) + and len(entry) == 1 + and self.__class__.__name__ in entry + ) + + def _parse_if_model_config_entry( + self, entry: Dict[str, Any] + ) -> Union["ModelConfig", Any]: + """Parse dictionary entry to `ModelConfig`.""" + if self._is_model_config_entry(entry): + config_dict = entry[self.__class__.__name__] + config = self.__class__(**config_dict) + return config + else: + return entry + + def _construct_model( + self, + trust: bool = False, + load_modules: Optional[List[str]] = None, + ) -> "Model": + """Construct `Model` instance from `self` configuration. + + Used as the basis for `Model.from_config`. + """ + # Check(s) + if load_modules is None: + load_modules = ["torch"] + assert isinstance(load_modules, list) + + # Load any additional modules into the global namespace + for module in load_modules: + assert re.match("^[a-zA-Z_]+$", module) is not None + if module in globals(): + continue + exec(f"import {module}", globals()) + + # Get a lookup for all classes in `graphnet` + import graphnet.data + import graphnet.models + import graphnet.training + + namespace_classes = get_all_grapnet_classes( + graphnet.data, graphnet.models, graphnet.training + ) + + # Parse potential ModelConfig arguments + arguments = dict(**self.arguments) + arguments = traverse_and_apply( + arguments, + self._deserialise, + fn_kwargs={"trust": trust}, + ) + + # Construct model based on arguments + return namespace_classes[self.class_name](**arguments) + + @classmethod + def _deserialise(cls, obj: Any, trust: bool = False) -> Any: + if isinstance(obj, ModelConfig): + from graphnet.models import Model + + return Model.from_config(obj, trust=trust) + + elif isinstance(obj, str) and obj.startswith("!lambda"): + if trust: + source = obj[1:] + f = eval(source) + + # Save a copy of the source code attached to the callable, + # since the `inspect` module is not able to get the source code + # for functions that are not defined on file. + # See `self._serialise`. + f._source = source + return f + else: + raise ValueError( + "Constructing model containing a lambda function " + f"({obj}) with `trust=False`. If you trust the lambda " + "functions in this ModelConfig, set `trust=True` and " + "reconstruct the model again." + ) + + elif isinstance(obj, str) and obj.startswith("!function"): + if trust: + source = obj[10:] + match = re.match(FUNCTION_DEFINITION_PATTERN, source) + assert match + exec(source) + fn = eval(match.group("function_name")) + return fn + else: + raise ValueError( + f"Constructing model containing a function ({obj}) with " + "`trust=False`. If you trust the functions in this " + "ModelConfig, set `trust=True` and reconstruct the model " + "again." + ) + + elif isinstance(obj, str) and obj.startswith("!class"): + if trust: + module, class_name = obj.split()[1:] + exec(f"from {module} import {class_name}") + return eval(class_name) + else: + raise ValueError( + f"Constructing model containing a class ({obj}) with " + "`trust=False`. If you trust the class definitions in " + "this ModelConfig, set `trust=True` and reconstruct the " + "model again." + ) + elif isinstance(obj, str) and obj.startswith("torch"): + return eval(obj) + + else: + return obj + + @classmethod + def _serialise(cls, obj: Any) -> Any: + """Serialise `obj` to a format that can be saved to file.""" + if isinstance(obj, ModelConfig): + return obj.as_dict() + elif isinstance(obj, type): + return f"!class {obj.__module__} {obj.__name__}" + elif isinstance(obj, torch.dtype): + return obj.__str__() + elif isinstance(obj, Callable): # type: ignore[arg-type] + if hasattr(obj, "__name__") and obj.__name__ == "<lambda>": + if hasattr(obj, "_source"): + # If source code is set manually during deserialisation. + # See `self._deserialise`. + source = obj._source + else: + source = inspect.getsource(obj).split("=")[1].strip("\n ,") + + return "!" + source + else: + try: + source = inspect.getsource(obj) + match = re.match(FUNCTION_DEFINITION_PATTERN, source) + if match and match.group("function_name"): + return f"!function {source}" + else: + raise ValueError + except (TypeError, ValueError): + raise ValueError( + f"Object `{obj}` is callable but not a lambda or " + "regular function. Please wrap in a, e.g., lambda " + "function to allow for saving this function verbatim " + "in a model config file." + ) + + return obj + +
+[docs] + def as_dict(self) -> Dict[str, Dict[str, Any]]: + """Represent ModelConfig as a dict. + + This builds on `BaseModel.dict()` but wraps the output in a single-key + dictionary to make it unambiguous to identify model arguments that are + themselves models. + """ + config_dict = self.dict() + config_dict["arguments"] = traverse_and_apply( + self.arguments, self._serialise + ) + + return {self.__class__.__name__: config_dict}
+
+ + + +
+[docs] +def save_model_config(init_fn: Callable) -> Callable: + """Save the arguments to `__init__` functions as a member `ModelConfig`.""" + + def _replace_model_instance_with_config( + obj: Union["Model", Any] + ) -> Union[ModelConfig, Any]: + """Replace `Model` instances in `obj` with their `ModelConfig`.""" + from graphnet.models import Model + + if isinstance(obj, Model): + return obj.config + else: + return obj + + @wraps(init_fn) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + """Set `ModelConfig` after calling `init_fn`.""" + # Call wrapped method + ret = init_fn(self, *args, **kwargs) + + # Get all argument values, including defaults + cfg = get_all_argument_values(init_fn, *args, **kwargs) + + # Handle nested `Model`s, etc. + cfg = traverse_and_apply(cfg, _replace_model_instance_with_config) + + # Add `ModelConfig` as member variables + self._config = ModelConfig( + class_name=str(self.__class__.__name__), + arguments=dict(**cfg), + ) + + return ret + + return wrapper
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/parsing.html b/_modules/graphnet/utilities/config/parsing.html new file mode 100644 index 000000000..e1bff4682 --- /dev/null +++ b/_modules/graphnet/utilities/config/parsing.html @@ -0,0 +1,475 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.utilities.config.parsing — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.utilities.config.parsing

+"""Utility functions for parsing for using with Config-classes."""
+
+import itertools
+import pkgutil
+import types
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+)
+
+from graphnet.utilities.logging import Logger
+
+
+
+[docs] +def traverse_and_apply( + obj: Any, fn: Callable, fn_kwargs: Optional[Dict[str, Any]] = None +) -> Any: + """Apply `fn` to all elements in `obj`, resulting in same structure.""" + if isinstance(obj, (list, tuple)): + return [traverse_and_apply(elem, fn, fn_kwargs) for elem in obj] + elif isinstance(obj, dict): + return { + key: traverse_and_apply(val, fn, fn_kwargs) + for key, val in obj.items() + } + else: + if fn_kwargs is None: + fn_kwargs = {} + return fn(obj, **fn_kwargs)
+ + + +
+[docs] +def list_all_submodules(*packages: types.ModuleType) -> List[types.ModuleType]: + """List all submodules in `packages` recursively.""" + # Resolve one or more packages + if len(packages) > 1: + return list( + itertools.chain.from_iterable(map(list_all_submodules, packages)) + ) + else: + assert len(packages) == 1, "No packages specified" + package = packages[0] + + submodules: List[types.ModuleType] = [] + for _, module_name, is_pkg in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + module = __import__(module_name, fromlist="dummylist") + submodules.append(module) + if is_pkg: + submodules.extend(list_all_submodules(module)) + + return submodules
+ + + +
+[docs] +def get_all_grapnet_classes(*packages: types.ModuleType) -> Dict[str, type]: + """List all grapnet classes in `packages`.""" + submodules = list_all_submodules(*packages) + classes: Dict[str, type] = {} + for submodule in submodules: + new_classes = get_graphnet_classes(submodule) + for key in new_classes: + if key in classes and classes[key] != new_classes[key]: + Logger().warning( + f"Class {key} found in both {classes[key]} and " + f"{new_classes[key]}. Keeping first instance. " + "Consider renaming." + ) + classes.update(new_classes) + + return classes
+ + + +
+[docs] +def is_graphnet_module(obj: types.ModuleType) -> bool: + """Return whether `obj` is a module in graphnet.""" + return isinstance(obj, types.ModuleType) and obj.__name__.startswith( + "graphnet." + )
+ + + +
+[docs] +def is_graphnet_class(obj: type) -> bool: + """Return whether `obj` is a class in graphnet.""" + return isinstance(obj, type) and obj.__module__.startswith("graphnet.")
+ + + +
+[docs] +def get_graphnet_classes(module: types.ModuleType) -> Dict[str, type]: + """Return a lookup of all graphnet class names in `module`.""" + if not is_graphnet_module(module): + Logger().info(f"{module} is not a graphnet module") + return {} + classes = { + key: val + for key, val in module.__dict__.items() + if is_graphnet_class(val) + } + return classes
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/utilities/config/training_config.html b/_modules/graphnet/utilities/config/training_config.html new file mode 100644 index 000000000..add0468a2 --- /dev/null +++ b/_modules/graphnet/utilities/config/training_config.html @@ -0,0 +1,378 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.utilities.config.training_config — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.utilities.config.training_config

+"""Config classes for the `graphnet.training` module."""
+
+from typing import Any, Dict, List, Union
+
+from graphnet.utilities.config.base_config import BaseConfig
+
+
+
+[docs] +class TrainingConfig(BaseConfig): + """Configuration for all trainings.""" + + # Fields + target: Union[str, List[str]] + early_stopping_patience: int + fit: Dict[str, Any] + dataloader: Dict[str, Any]
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/graphnet/utilities/filesys.html b/_modules/graphnet/utilities/filesys.html index ef2600ad3..121c9bfe5 100644 --- a/_modules/graphnet/utilities/filesys.html +++ b/_modules/graphnet/utilities/filesys.html @@ -447,7 +447,7 @@

Source code for graphnet. Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/utilities/imports.html b/_modules/graphnet/utilities/imports.html index 40699ad6e..80debdadf 100644 --- a/_modules/graphnet/utilities/imports.html +++ b/_modules/graphnet/utilities/imports.html @@ -420,7 +420,7 @@

Source code for graphnet. Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/utilities/logging.html b/_modules/graphnet/utilities/logging.html index cb014897d..aa1b8a091 100644 --- a/_modules/graphnet/utilities/logging.html +++ b/_modules/graphnet/utilities/logging.html @@ -630,7 +630,7 @@

Source code for graphnet. Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/_modules/graphnet/utilities/maths.html b/_modules/graphnet/utilities/maths.html new file mode 100644 index 000000000..607211db6 --- /dev/null +++ b/_modules/graphnet/utilities/maths.html @@ -0,0 +1,371 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.utilities.maths — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.utilities.maths

+"""Collection of assorted "maths-like" functions."""
+
+import torch
+
+
+
+[docs] +def eps_like(tensor: torch.Tensor) -> torch.Tensor: + """Return `eps` matching `tensor`'s dtype.""" + return torch.finfo(tensor.dtype).eps
+ +
+ +
+
+
+
+
+
+ + +
+ + + + \ No newline at end of file diff --git a/_modules/index.html b/_modules/index.html index 89db4987e..120145b62 100644 --- a/_modules/index.html +++ b/_modules/index.html @@ -323,6 +323,11 @@

All modules for which code is available

@@ -375,7 +414,7 @@

All modules for which code is available

Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/about.html b/about.html index 3ebaa11ee..6422d9cf3 100644 --- a/about.html +++ b/about.html @@ -392,7 +392,7 @@

AcknowledgementsSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.constants.html b/api/graphnet.constants.html index 5cc170ec7..0a215dbf4 100644 --- a/api/graphnet.constants.html +++ b/api/graphnet.constants.html @@ -422,7 +422,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.constants.html b/api/graphnet.data.constants.html index 36edc3f62..32d33f303 100644 --- a/api/graphnet.data.constants.html +++ b/api/graphnet.data.constants.html @@ -700,7 +700,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataconverter.html b/api/graphnet.data.dataconverter.html index 61d9300a1..40c210634 100644 --- a/api/graphnet.data.dataconverter.html +++ b/api/graphnet.data.dataconverter.html @@ -820,7 +820,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataloader.html b/api/graphnet.data.dataloader.html index 8ad405a2b..98bb280f5 100644 --- a/api/graphnet.data.dataloader.html +++ b/api/graphnet.data.dataloader.html @@ -365,11 +365,54 @@ + +
  • @@ -436,7 +479,22 @@
    @@ -446,8 +504,73 @@
    -
    -

    dataloader

    +
    +

    dataloader

    +

    Base Dataloader class(es) used in graphnet.

    +
    +
    +graphnet.data.dataloader.collate_fn(graphs)[source]
    +

    Remove graphs with less than two DOM hits.

    +

    Should not occur in “production.

    +
    +
    Return type:
    +

    Batch

    +
    +
    Parameters:
    +

    graphs (List[Data]) –

    +
    +
    +
    +
    +
    +graphnet.data.dataloader.do_shuffle(selection_name)[source]
    +

    Check whether to shuffle selection with name selection_name.

    +
    +
    Return type:
    +

    bool

    +
    +
    Parameters:
    +

    selection_name (str) –

    +
    +
    +
    +
    +
    +class graphnet.data.dataloader.DataLoader(dataset, batch_size, shuffle, num_workers, persistent_workers, collate_fn=<function collate_fn>, prefetch_factor, **kwargs)[source]
    +

    Bases: DataLoader

    +

    Class for loading data from a Dataset.

    +

    Construct DataLoader.

    +
    +
    Parameters:
    +
      +
    • dataset (Dataset[T_co]) –

    • +
    • batch_size (int | None) –

    • +
    • shuffle (bool) –

    • +
    • num_workers (int) –

    • +
    • persistent_workers (bool) –

    • +
    • collate_fn (Callable) –

    • +
    • prefetch_factor (int | None) –

    • +
    • kwargs (Any) –

    • +
    +
    +
    +
    +
    +classmethod from_dataset_config(config, **kwargs)[source]
    +

    Construct DataLoader`s based on selections in `DatasetConfig.

    +
    +
    Return type:
    +

    Union[DataLoader, Dict[str, DataLoader]]

    +
    +
    Parameters:
    +
    +
    +
    +
    +
    @@ -497,7 +620,7 @@

    dataloader Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataset.dataset.html b/api/graphnet.data.dataset.dataset.html index 0750d2bf5..ac7ed7ec2 100644 --- a/api/graphnet.data.dataset.dataset.html +++ b/api/graphnet.data.dataset.dataset.html @@ -336,11 +336,117 @@ + +

  • @@ -458,7 +564,36 @@ @@ -468,8 +603,197 @@
    -
    -

    dataset

    +
    +

    dataset

    +

    Base Dataset class(es) used in GraphNeT.

    +
    +
    +exception graphnet.data.dataset.dataset.ColumnMissingException[source]
    +

    Bases: Exception

    +

    Exception to indicate a missing column in a dataset.

    +
    +
    +
    +graphnet.data.dataset.dataset.load_module(class_name)[source]
    +

    Load graphnet module from string name.

    +
    +
    Parameters:
    +

    class_name (str) – name of class

    +
    +
    Return type:
    +

    Type

    +
    +
    Returns:
    +

    graphnet module.

    +
    +
    +
    +
    +
    +graphnet.data.dataset.dataset.parse_graph_definition(cfg)[source]
    +

    Construct GraphDefinition from DatasetConfig.

    +
    +
    Return type:
    +

    GraphDefinition

    +
    +
    Parameters:
    +

    cfg (dict) –

    +
    +
    +
    +
    +
    +class graphnet.data.dataset.dataset.Dataset(path, graph_definition, pulsemaps, features, truth, *, node_truth, index_column, truth_table, node_truth_table, string_selection, selection, dtype, loss_weight_table, loss_weight_column, loss_weight_default_value, seed)[source]
    +

    Bases: Logger, Configurable, Dataset, ABC

    +

    Base Dataset class for reading from any intermediate file format.

    +

    Construct Dataset.

    +
    +
    Parameters:
    +
      +
    • path (Union[str, List[str]]) – Path to the file(s) from which this Dataset should read.

    • +
    • pulsemaps (Union[str, List[str]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.

    • +
    • features (List[str]) – List of columns in the input files that should be used as +node features on the graph objects.

    • +
    • truth (List[str]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.

    • +
    • node_truth (Optional[List[str]], default: None) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.

    • +
    • index_column (str, default: 'event_no') – Name of the column in the input files that contains +unique indicies to identify and map events across tables.

    • +
    • truth_table (str, default: 'truth') – Name of the table containing event-level truth +information.

    • +
    • node_truth_table (Optional[str], default: None) – Name of the table containing node-level truth +information.

    • +
    • string_selection (Optional[List[int]], default: None) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.

    • +
    • selection (Union[str, List[int], List[List[int]], None], default: None) – The events that should be read. This can be given either +as list of indicies (in index_column); or a string-based +selection used to query the Dataset for events passing the +selection. Defaults to None, meaning that all events in the +input files are read.

    • +
    • dtype (dtype, default: torch.float32) – Type of the feature tensor on the graph objects returned.

    • +
    • loss_weight_table (Optional[str], default: None) – Name of the table containing per-event loss +weights.

    • +
    • loss_weight_column (Optional[str], default: None) – Name of the column in loss_weight_table +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.

    • +
    • loss_weight_default_value (Optional[float], default: None) – Default per-event loss weight. +NOTE: This default value is only applied when +loss_weight_table and loss_weight_column are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.

    • +
    • seed (Optional[int], default: None) – Random number generator seed, used for selecting a random +subset of events when resolving a string-based selection (e.g., +“10000 random events ~ event_no % 5 > 0” or “20% random +events ~ event_no % 5 > 0”).

    • +
    • graph_definition (GraphDefinition) – Method that defines the graph representation.

    • +
    +
    +
    +
    +
    +classmethod from_config(source)[source]
    +

    Construct Dataset instance from source configuration.

    +
    +
    Return type:
    +

    Union[Dataset, EnsembleDataset, Dict[str, Dataset], Dict[str, EnsembleDataset]]

    +
    +
    Parameters:
    +

    source (DatasetConfig | str) –

    +
    +
    +
    +
    +
    +classmethod concatenate(datasets)[source]
    +

    Concatenate multiple `Dataset`s into one instance.

    +
    +
    Return type:
    +

    EnsembleDataset

    +
    +
    Parameters:
    +

    datasets (List[Dataset]) –

    +
    +
    +
    +
    +
    +property path: str | List[str]
    +

    Path to the file(s) from which this Dataset reads.

    +
    +
    +
    +property truth_table: str
    +

    Name of the table containing event-level truth information.

    +
    +
    +
    +abstract query_table(table, columns, sequential_index, selection)[source]
    +

    Query a table at a specific index, optionally with some selection.

    +
    +
    Parameters:
    +
      +
    • table (str) – Table to be queried.

    • +
    • columns (Union[List[str], str]) – Columns to read out.

    • +
    • sequential_index (Optional[int], default: None) – Sequentially numbered index +(i.e. in [0,len(self))) of the event to query. This _may_ +differ from the indexation used in self._indices. If no value +is provided, the entire column is returned.

    • +
    • selection (Optional[str], default: None) – Selection to be imposed before reading out data. +Defaults to None.

    • +
    +
    +
    Return type:
    +

    List[Tuple[Any, ...]]

    +
    +
    Returns:
    +

    +
    List of tuples containing the values in columns. If the table

    contains only scalar data for columns, a list of length 1 is +returned

    +
    +
    +

    +
    +
    Raises:
    +

    ColumnMissingException – If one or more element in columns is not + present in table.

    +
    +
    +
    +
    +
    +add_label(fn, key)[source]
    +

    Add custom graph label define using function fn.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +
      +
    • fn (Callable[[Data], Any]) –

    • +
    • key (str | None) –

    • +
    +
    +
    +
    +
    +
    +
    +class graphnet.data.dataset.dataset.EnsembleDataset(datasets)[source]
    +

    Bases: ConcatDataset

    +

    Construct a single dataset from a collection of datasets.

    +

    Construct a single dataset from a collection of datasets.

    +
    +
    Parameters:
    +

    datasets (Iterable[Dataset]) – A collection of Datasets

    +
    +
    +
    @@ -519,7 +843,7 @@

    datasetSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataset.html b/api/graphnet.data.dataset.html index 9c721d294..a4b095f8e 100644 --- a/api/graphnet.data.dataset.html +++ b/api/graphnet.data.dataset.html @@ -467,8 +467,9 @@
    -
    -

    dataset

    +
    +

    dataset

    +

    Dataset classes for training in GraphNeT.

    Subpackages

    @@ -538,7 +546,7 @@

    dataset Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataset.parquet.html b/api/graphnet.data.dataset.parquet.html index 873efa2c0..492684fad 100644 --- a/api/graphnet.data.dataset.parquet.html +++ b/api/graphnet.data.dataset.parquet.html @@ -475,12 +475,16 @@
    -
    -

    parquet

    +
    +

    parquet

    +

    Datasets using parquet backend.

    Submodules

    @@ -532,7 +536,7 @@

    parquetSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataset.parquet.parquet_dataset.html b/api/graphnet.data.dataset.parquet.parquet_dataset.html index 67cfc6f08..3598486cf 100644 --- a/api/graphnet.data.dataset.parquet.parquet_dataset.html +++ b/api/graphnet.data.dataset.parquet.parquet_dataset.html @@ -328,11 +328,36 @@ + + @@ -466,7 +491,18 @@
    @@ -476,8 +512,82 @@
    -
    -

    parquet_dataset

    +
    +

    parquet_dataset

    +

    Dataset class(es) for reading from Parquet files.

    +
    +
    +class graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset(path, graph_definition, pulsemaps, features, truth, *, node_truth, index_column, truth_table, node_truth_table, string_selection, selection, dtype, loss_weight_table, loss_weight_column, loss_weight_default_value, seed)[source]
    +

    Bases: Dataset

    +

    Pytorch dataset for reading from Parquet files.

    +

    Construct Dataset.

    +
    +
    Parameters:
    +
      +
    • path (Union[str, List[str]]) – Path to the file(s) from which this Dataset should read.

    • +
    • pulsemaps (Union[str, List[str]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.

    • +
    • features (List[str]) – List of columns in the input files that should be used as +node features on the graph objects.

    • +
    • truth (List[str]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.

    • +
    • node_truth (Optional[List[str]], default: None) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.

    • +
    • index_column (str, default: 'event_no') – Name of the column in the input files that contains +unique indicies to identify and map events across tables.

    • +
    • truth_table (str, default: 'truth') – Name of the table containing event-level truth +information.

    • +
    • node_truth_table (Optional[str], default: None) – Name of the table containing node-level truth +information.

    • +
    • string_selection (Optional[List[int]], default: None) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.

    • +
    • selection (Union[str, List[int], List[List[int]], None], default: None) – The events that should be read. This can be given either +as list of indicies (in index_column); or a string-based +selection used to query the Dataset for events passing the +selection. Defaults to None, meaning that all events in the +input files are read.

    • +
    • dtype (dtype, default: torch.float32) – Type of the feature tensor on the graph objects returned.

    • +
    • loss_weight_table (Optional[str], default: None) – Name of the table containing per-event loss +weights.

    • +
    • loss_weight_column (Optional[str], default: None) – Name of the column in loss_weight_table +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.

    • +
    • loss_weight_default_value (Optional[float], default: None) – Default per-event loss weight. +NOTE: This default value is only applied when +loss_weight_table and loss_weight_column are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.

    • +
    • seed (Optional[int], default: None) – Random number generator seed, used for selecting a random +subset of events when resolving a string-based selection (e.g., +“10000 random events ~ event_no % 5 > 0” or “20% random +events ~ event_no % 5 > 0”).

    • +
    • graph_definition (GraphDefinition) – Method that defines the graph representation.

    • +
    +
    +
    +
    +
    +query_table(table, columns, sequential_index, selection)[source]
    +

    Query table at a specific index, optionally with some selection.

    +
    +
    Return type:
    +

    List[Tuple[Any, ...]]

    +
    +
    Parameters:
    +
      +
    • table (str) –

    • +
    • columns (List[str] | str) –

    • +
    • sequential_index (int | None) –

    • +
    • selection (str | None) –

    • +
    +
    +
    +
    +
    @@ -527,7 +637,7 @@

    parquet_da

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataset.sqlite.html b/api/graphnet.data.dataset.sqlite.html index 9cfc9a2a5..8ea0f5783 100644 --- a/api/graphnet.data.dataset.sqlite.html +++ b/api/graphnet.data.dataset.sqlite.html @@ -482,13 +482,20 @@
    -
    -

    sqlite

    +
    +

    sqlite

    +

    Datasets using SQLite backend.

    Submodules

    @@ -540,7 +547,7 @@

    sqlite Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html index 36b0988a6..7d25db896 100644 --- a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html +++ b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html @@ -335,11 +335,36 @@ + +
  • @@ -473,7 +498,18 @@
    @@ -483,8 +519,82 @@
    -
    -

    sqlite_dataset

    +
    +

    sqlite_dataset

    +

    Dataset class(es) for reading data from SQLite databases.

    +
    +
    +class graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset(path, graph_definition, pulsemaps, features, truth, *, node_truth, index_column, truth_table, node_truth_table, string_selection, selection, dtype, loss_weight_table, loss_weight_column, loss_weight_default_value, seed)[source]
    +

    Bases: Dataset

    +

    Pytorch dataset for reading data from SQLite databases.

    +

    Construct Dataset.

    +
    +
    Parameters:
    +
      +
    • path (Union[str, List[str]]) – Path to the file(s) from which this Dataset should read.

    • +
    • pulsemaps (Union[str, List[str]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.

    • +
    • features (List[str]) – List of columns in the input files that should be used as +node features on the graph objects.

    • +
    • truth (List[str]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.

    • +
    • node_truth (Optional[List[str]], default: None) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.

    • +
    • index_column (str, default: 'event_no') – Name of the column in the input files that contains +unique indicies to identify and map events across tables.

    • +
    • truth_table (str, default: 'truth') – Name of the table containing event-level truth +information.

    • +
    • node_truth_table (Optional[str], default: None) – Name of the table containing node-level truth +information.

    • +
    • string_selection (Optional[List[int]], default: None) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.

    • +
    • selection (Union[str, List[int], List[List[int]], None], default: None) – The events that should be read. This can be given either +as list of indicies (in index_column); or a string-based +selection used to query the Dataset for events passing the +selection. Defaults to None, meaning that all events in the +input files are read.

    • +
    • dtype (dtype, default: torch.float32) – Type of the feature tensor on the graph objects returned.

    • +
    • loss_weight_table (Optional[str], default: None) – Name of the table containing per-event loss +weights.

    • +
    • loss_weight_column (Optional[str], default: None) – Name of the column in loss_weight_table +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.

    • +
    • loss_weight_default_value (Optional[float], default: None) – Default per-event loss weight. +NOTE: This default value is only applied when +loss_weight_table and loss_weight_column are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.

    • +
    • seed (Optional[int], default: None) – Random number generator seed, used for selecting a random +subset of events when resolving a string-based selection (e.g., +“10000 random events ~ event_no % 5 > 0” or “20% random +events ~ event_no % 5 > 0”).

    • +
    • graph_definition (GraphDefinition) – Method that defines the graph representation.

    • +
    +
    +
    +
    +
    +query_table(table, columns, sequential_index, selection)[source]
    +

    Query table at a specific index, optionally with some selection.

    +
    +
    Return type:
    +

    List[Tuple[Any, ...]]

    +
    +
    Parameters:
    +
      +
    • table (str) –

    • +
    • columns (List[str] | str) –

    • +
    • sequential_index (int | None) –

    • +
    • selection (str | None) –

    • +
    +
    +
    +
    +
    @@ -534,7 +644,7 @@

    sqlite_datas

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html b/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html index e5b66ee43..f46386964 100644 --- a/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html +++ b/api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.html @@ -342,11 +342,25 @@ + +
  • @@ -473,7 +487,14 @@ @@ -483,8 +504,65 @@
    -
    -

    sqlite_dataset_perturbed

    +
    +

    sqlite_dataset_perturbed

    +

    Dataset class(es) for reading perturbed data from SQLite databases.

    +
    +
    +class graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed(path, pulsemaps, features, truth, *, perturbation_dict, node_truth, index_column, truth_table, node_truth_table, string_selection, selection, dtype, loss_weight_table, loss_weight_column, loss_weight_default_value, seed)[source]
    +

    Bases: SQLiteDataset

    +

    Pytorch dataset for reading perturbed data from SQLite databases.

    +

    This including a pre-processing step, where the input data is randomly +perturbed according to given per-feature “noise” levels. This is intended +to test the stability of a trained model under small changes to the input +parameters.

    +

    Construct SQLiteDatasetPerturbed.

    +
    +
    Parameters:
    +
      +
    • path (Union[str, List[str]]) – Path to the file(s) from which this Dataset should read.

    • +
    • pulsemaps (Union[str, List[str]]) – Name(s) of the pulse map series that should be used to +construct the nodes on the individual graph objects, and their +features. Multiple pulse series maps can be used, e.g., when +different DOM types are stored in different maps.

    • +
    • features (List[str]) – List of columns in the input files that should be used as +node features on the graph objects.

    • +
    • truth (List[str]) – List of event-level columns in the input files that should +be used added as attributes on the graph objects.

    • +
    • perturbation_dict (Dict[str, float]) – Dictionary mapping a feature +name to a standard deviation according to which the values for +this feature should be randomly perturbed.

    • +
    • node_truth (Optional[List[str]], default: None) – List of node-level columns in the input files that +should be used added as attributes on the graph objects.

    • +
    • index_column (str, default: 'event_no') – Name of the column in the input files that contains +unique indicies to identify and map events across tables.

    • +
    • truth_table (str, default: 'truth') – Name of the table containing event-level truth +information.

    • +
    • node_truth_table (Optional[str], default: None) – Name of the table containing node-level truth +information.

    • +
    • string_selection (Optional[List[int]], default: None) – Subset of strings for which data should be read +and used to construct graph objects. Defaults to None, meaning +all strings for which data exists are used.

    • +
    • selection (Optional[List[int]], default: None) – List of indicies (in index_column) of the events in +the input files that should be read. Defaults to None, meaning +that all events in the input files are read.

    • +
    • dtype (dtype, default: torch.float32) – Type of the feature tensor on the graph objects returned.

    • +
    • loss_weight_table (Optional[str], default: None) – Name of the table containing per-event loss +weights.

    • +
    • loss_weight_column (Optional[str], default: None) – Name of the column in loss_weight_table +containing per-event loss weights. This is also the name of the +corresponding attribute assigned to the graph object.

    • +
    • loss_weight_default_value (Optional[float], default: None) – Default per-event loss weight. +NOTE: This default value is only applied when +loss_weight_table and loss_weight_column are specified, and +in this case to events with no value in the corresponding +table/column. That is, if no per-event loss weight table/column +is provided, this value is ignored. Defaults to None.

    • +
    • seed (Union[int, Generator, None], default: None) – Optional seed for random number generation. Defaults to None.

    • +
    +
    +
    +
    @@ -534,7 +612,7 @@

    sq

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.html b/api/graphnet.data.extractors.html index df9ce380e..5810ea8fc 100644 --- a/api/graphnet.data.extractors.html +++ b/api/graphnet.data.extractors.html @@ -658,7 +658,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3extractor.html b/api/graphnet.data.extractors.i3extractor.html index 884ec4690..4671c09f7 100644 --- a/api/graphnet.data.extractors.i3extractor.html +++ b/api/graphnet.data.extractors.i3extractor.html @@ -732,7 +732,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3featureextractor.html b/api/graphnet.data.extractors.i3featureextractor.html index fd27d93af..520605dfe 100644 --- a/api/graphnet.data.extractors.i3featureextractor.html +++ b/api/graphnet.data.extractors.i3featureextractor.html @@ -720,7 +720,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3genericextractor.html b/api/graphnet.data.extractors.i3genericextractor.html index fffb0ae7a..2915de7ce 100644 --- a/api/graphnet.data.extractors.i3genericextractor.html +++ b/api/graphnet.data.extractors.i3genericextractor.html @@ -639,7 +639,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3hybridrecoextractor.html b/api/graphnet.data.extractors.i3hybridrecoextractor.html index c13a47727..e23c93bd9 100644 --- a/api/graphnet.data.extractors.i3hybridrecoextractor.html +++ b/api/graphnet.data.extractors.i3hybridrecoextractor.html @@ -623,7 +623,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html index d87d2a70a..f9a853695 100644 --- a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html +++ b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html @@ -626,7 +626,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3particleextractor.html b/api/graphnet.data.extractors.i3particleextractor.html index ed80239fc..fbdb86f6e 100644 --- a/api/graphnet.data.extractors.i3particleextractor.html +++ b/api/graphnet.data.extractors.i3particleextractor.html @@ -625,7 +625,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3pisaextractor.html b/api/graphnet.data.extractors.i3pisaextractor.html index 82a7edbf0..4471bf09a 100644 --- a/api/graphnet.data.extractors.i3pisaextractor.html +++ b/api/graphnet.data.extractors.i3pisaextractor.html @@ -623,7 +623,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3quesoextractor.html b/api/graphnet.data.extractors.i3quesoextractor.html index 155c83c34..649bb0171 100644 --- a/api/graphnet.data.extractors.i3quesoextractor.html +++ b/api/graphnet.data.extractors.i3quesoextractor.html @@ -626,7 +626,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3retroextractor.html b/api/graphnet.data.extractors.i3retroextractor.html index 579e01cb4..86f076405 100644 --- a/api/graphnet.data.extractors.i3retroextractor.html +++ b/api/graphnet.data.extractors.i3retroextractor.html @@ -623,7 +623,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3splinempeextractor.html b/api/graphnet.data.extractors.i3splinempeextractor.html index fd7c55164..f229b0090 100644 --- a/api/graphnet.data.extractors.i3splinempeextractor.html +++ b/api/graphnet.data.extractors.i3splinempeextractor.html @@ -623,7 +623,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3truthextractor.html b/api/graphnet.data.extractors.i3truthextractor.html index 8e211b71a..213393e35 100644 --- a/api/graphnet.data.extractors.i3truthextractor.html +++ b/api/graphnet.data.extractors.i3truthextractor.html @@ -630,7 +630,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.i3tumextractor.html b/api/graphnet.data.extractors.i3tumextractor.html index 4b4ad2d29..326699de0 100644 --- a/api/graphnet.data.extractors.i3tumextractor.html +++ b/api/graphnet.data.extractors.i3tumextractor.html @@ -623,7 +623,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.utilities.collections.html b/api/graphnet.data.extractors.utilities.collections.html index d963c742e..0970c915a 100644 --- a/api/graphnet.data.extractors.utilities.collections.html +++ b/api/graphnet.data.extractors.utilities.collections.html @@ -708,7 +708,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.utilities.frames.html b/api/graphnet.data.extractors.utilities.frames.html index f895882b1..2fcdda991 100644 --- a/api/graphnet.data.extractors.utilities.frames.html +++ b/api/graphnet.data.extractors.utilities.frames.html @@ -709,7 +709,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.utilities.html b/api/graphnet.data.extractors.utilities.html index da9334938..d7361ced0 100644 --- a/api/graphnet.data.extractors.utilities.html +++ b/api/graphnet.data.extractors.utilities.html @@ -640,7 +640,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.extractors.utilities.types.html b/api/graphnet.data.extractors.utilities.types.html index 6c8154f02..10266dd2d 100644 --- a/api/graphnet.data.extractors.utilities.types.html +++ b/api/graphnet.data.extractors.utilities.types.html @@ -867,7 +867,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.html b/api/graphnet.data.html index dff3daeae..57b5f245a 100644 --- a/api/graphnet.data.html +++ b/api/graphnet.data.html @@ -507,8 +507,16 @@
  • DataConverter
  • -
  • dataloader
  • -
  • pipeline
  • +
  • dataloader +
  • +
  • pipeline +
  • @@ -560,7 +568,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.parquet.html b/api/graphnet.data.parquet.html index 31148efc3..ebb6a3f5f 100644 --- a/api/graphnet.data.parquet.html +++ b/api/graphnet.data.parquet.html @@ -514,7 +514,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.parquet.parquet_dataconverter.html b/api/graphnet.data.parquet.parquet_dataconverter.html index 3b17b9603..a0f810016 100644 --- a/api/graphnet.data.parquet.parquet_dataconverter.html +++ b/api/graphnet.data.parquet.parquet_dataconverter.html @@ -665,7 +665,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.pipeline.html b/api/graphnet.data.pipeline.html index 633c55e35..3e14a0e83 100644 --- a/api/graphnet.data.pipeline.html +++ b/api/graphnet.data.pipeline.html @@ -372,11 +372,25 @@ + + @@ -436,7 +450,14 @@
    @@ -446,8 +467,36 @@
    -
    -

    pipeline

    +
    +

    pipeline

    +

    Class(es) used for analysis in PISA.

    +
    +
    +class graphnet.data.pipeline.InSQLitePipeline(module_dict, features, truth, device, retro_table_name, outdir, batch_size, n_workers, pipeline_name)[source]
    +

    Bases: ABC, Logger

    +

    Create a SQLite database for PISA analysis.

    +

    The database will contain truth and GNN predictions and, if available, +RETRO reconstructions.

    +

    Initialise the pipeline.

    +
    +
    Parameters:
    +
      +
    • module_dict (Dict) – A dictionary with GNN modules from GraphNet. E.g. +{‘energy’: gnn_module_for_energy_regression}

    • +
    • features (List[str]) – List of input features for the GNN modules.

    • +
    • truth (List[str]) – List of truth for the GNN ModuleList.

    • +
    • device (device) – The device used for computation.

    • +
    • retro_table_name (str, default: 'retro') – Name of the retro table for.

    • +
    • outdir (Optional[str], default: None) – the directory in which the pipeline database will be +stored.

    • +
    • batch_size (int, default: 100) – Batch size for inference.

    • +
    • n_workers (int, default: 10) – Number of workers used in dataloading.

    • +
    • pipeline_name (str, default: 'pipeline') – Name of the pipeline. If such a pipeline already +exists, an error will be prompted to avoid overwriting.

    • +
    +
    +
    +
    @@ -497,7 +546,7 @@

    pipeline Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.sqlite.html b/api/graphnet.data.sqlite.html index 16343c670..cd932166c 100644 --- a/api/graphnet.data.sqlite.html +++ b/api/graphnet.data.sqlite.html @@ -534,7 +534,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.sqlite.sqlite_dataconverter.html b/api/graphnet.data.sqlite.sqlite_dataconverter.html index 302c9e50d..dfbb879df 100644 --- a/api/graphnet.data.sqlite.sqlite_dataconverter.html +++ b/api/graphnet.data.sqlite.sqlite_dataconverter.html @@ -776,7 +776,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.sqlite.sqlite_utilities.html b/api/graphnet.data.sqlite.sqlite_utilities.html index b11e73e50..514fbdacf 100644 --- a/api/graphnet.data.sqlite.sqlite_utilities.html +++ b/api/graphnet.data.sqlite.sqlite_utilities.html @@ -727,7 +727,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.utilities.html b/api/graphnet.data.utilities.html index 35675598e..894827708 100644 --- a/api/graphnet.data.utilities.html +++ b/api/graphnet.data.utilities.html @@ -536,7 +536,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.utilities.parquet_to_sqlite.html b/api/graphnet.data.utilities.parquet_to_sqlite.html index eca0271ac..86d85b9eb 100644 --- a/api/graphnet.data.utilities.parquet_to_sqlite.html +++ b/api/graphnet.data.utilities.parquet_to_sqlite.html @@ -591,7 +591,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.utilities.random.html b/api/graphnet.data.utilities.random.html index e3b1572b6..ee45ec2a4 100644 --- a/api/graphnet.data.utilities.random.html +++ b/api/graphnet.data.utilities.random.html @@ -563,7 +563,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.data.utilities.string_selection_resolver.html b/api/graphnet.data.utilities.string_selection_resolver.html index 2ee151d57..2be85aed2 100644 --- a/api/graphnet.data.utilities.string_selection_resolver.html +++ b/api/graphnet.data.utilities.string_selection_resolver.html @@ -552,7 +552,7 @@
    Parameters:
      -
    • dataset (Dataset) –

    • +
    • dataset (Dataset) –

    • index_column (str) –

    • seed (int | None) –

    • use_cache (bool) –

    • @@ -626,7 +626,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.deployment.html b/api/graphnet.deployment.html index 534adda23..a21590d71 100644 --- a/api/graphnet.deployment.html +++ b/api/graphnet.deployment.html @@ -453,7 +453,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.deployment.i3modules.deployer.html b/api/graphnet.deployment.i3modules.deployer.html index 5a1c55118..c29615e41 100644 --- a/api/graphnet.deployment.i3modules.deployer.html +++ b/api/graphnet.deployment.i3modules.deployer.html @@ -456,7 +456,7 @@

      deployer Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.deployment.i3modules.graphnet_module.html b/api/graphnet.deployment.i3modules.graphnet_module.html index 0005fa3ad..ac93cd1ea 100644 --- a/api/graphnet.deployment.i3modules.graphnet_module.html +++ b/api/graphnet.deployment.i3modules.graphnet_module.html @@ -336,11 +336,43 @@ + +

    @@ -395,7 +427,18 @@ @@ -405,8 +448,94 @@
    -
    -

    graphnet_module

    +
    +

    graphnet_module

    +

    Class(es) for deploying GraphNeT models in icetray as I3Modules.

    +
    +
    +class graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module(graph_definition, pulsemap, features, pulsemap_extractor, gcd_file)[source]
    +

    Bases: object

    +

    Base I3 Module for GraphNeT.

    +

    Contains methods for extracting pulsemaps, producing graphs and writing to +frames.

    +

    I3Module Constructor.

    +
    +
    Parameters:
    +
      +
    • graph_definition (GraphDefinition) – An instance of GraphDefinition. E.g. KNNGraph.

    • +
    • pulsemap (str) – the pulse map on which the module functions

    • +
    • features (List[str]) – the features that is used from the pulse map. +E.g. [dom_x, dom_y, dom_z, charge]

    • +
    • pulsemap_extractor (Union[List[I3FeatureExtractor], I3FeatureExtractor]) – The I3FeatureExtractor used to extract the +pulsemap from the I3Frames

    • +
    • gcd_file (str) – Path to the associated gcd-file.

    • +
    +
    +
    +
    +
    +
    +class graphnet.deployment.i3modules.graphnet_module.I3InferenceModule(pulsemap, features, pulsemap_extractor, model_config, state_dict, model_name, gcd_file, prediction_columns)[source]
    +

    Bases: GraphNeTI3Module

    +

    General class for inference on i3 frames.

    +

    General class for inference on I3Frames (physics).

    +
    +
    Parameters:
    +
      +
    • pulsemap (str) – the pulsmap that the model is expecting as input.

    • +
    • features (List[str]) – the features of the pulsemap that the model is expecting.

    • +
    • pulsemap_extractor (Union[List[I3FeatureExtractor], I3FeatureExtractor]) – The extractor used to extract the pulsemap.

    • +
    • model_config (Union[ModelConfig, str]) – The ModelConfig (or path to it) that summarizes the +model used for inference.

    • +
    • state_dict (str) – Path to state_dict containing the learned weights.

    • +
    • model_name (str) – The name used for the model. Will help define the +named entry in the I3Frame. E.g. “dynedge”.

    • +
    • gcd_file (str) – path to associated gcd file.

    • +
    • prediction_columns (Union[str, List[str], None], default: None) –

      column names for the predictions of the model. +Will help define the named entry in the I3Frame.

      +
      +

      E.g. [‘energy_reco’]. Optional.

      +
      +

    • +
    +
    +
    +
    +
    +
    +class graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule(pulsemap, features, pulsemap_extractor, model_config, state_dict, model_name, *, gcd_file, threshold, discard_empty_events, prediction_columns)[source]
    +

    Bases: I3InferenceModule

    +

    A specialized module for pulse cleaning.

    +

    It is assumed that the model provided has been trained for this.

    +

    General class for inference on I3Frames (physics).

    +
    +
    Parameters:
    +
      +
    • pulsemap (str) – the pulsmap that the model is expecting as input +(the one that is being cleaned).

    • +
    • features (List[str]) – the features of the pulsemap that the model is expecting.

    • +
    • pulsemap_extractor (Union[List[I3FeatureExtractor], I3FeatureExtractor]) – The extractor used to extract the pulsemap.

    • +
    • model_config (str) – The ModelConfig (or path to it) that summarizes the +model used for inference.

    • +
    • state_dict (str) – Path to state_dict containing the learned weights.

    • +
    • model_name (str) – The name used for the model. Will help define the named +entry in the I3Frame. E.g. “dynedge”.

    • +
    • gcd_file (str) – path to associated gcd file.

    • +
    • threshold (float, default: 0.7) – the threshold for being considered a positive case. +E.g., predictions >= threshold will be considered +to be signal, all else noise.

    • +
    • discard_empty_events (bool, default: False) – When true, this flag will eliminate events +whose cleaned pulse series are empty. Can be used +to speed up processing especially for noise +simulation, since it will not do any writing or +further calculations.

    • +
    • prediction_columns (Union[str, List[str], None], default: None) – column names for the predictions of the model. +Will help define the named entry in the I3Frame. +E.g. [‘energy_reco’]. Optional.

    • +
    +
    +
    +
    @@ -456,7 +585,7 @@

    graphnet_m

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.deployment.i3modules.html b/api/graphnet.deployment.i3modules.html index cb57f2df6..975ce8207 100644 --- a/api/graphnet.deployment.i3modules.html +++ b/api/graphnet.deployment.i3modules.html @@ -410,7 +410,12 @@

    i3modules @@ -462,7 +467,7 @@

    i3modulesSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.html b/api/graphnet.html index b04aa430c..5cee1a97c 100644 --- a/api/graphnet.html +++ b/api/graphnet.html @@ -522,7 +522,7 @@ Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.coarsening.html b/api/graphnet.models.coarsening.html index 1dbd5b679..5341555ea 100644 --- a/api/graphnet.models.coarsening.html +++ b/api/graphnet.models.coarsening.html @@ -125,6 +125,7 @@ + @@ -365,11 +366,90 @@ + +
  • @@ -436,7 +516,30 @@ @@ -446,8 +549,125 @@
    -
    -

    coarsening

    +
    +

    coarsening

    +

    Class(es) for coarsening operations (i.e., clustering, or local pooling).

    +
    +
    +graphnet.models.coarsening.unbatch_edge_index(edge_index, batch)[source]
    +

    Splits the edge_index according to a batch vector.

    +
    +
    Parameters:
    +
      +
    • edge_index (Tensor) – The edge_index tensor. Must be ordered.

    • +
    • batch (LongTensor) – The batch vector +\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each +node to a specific example. Must be ordered.

    • +
    +
    +
    Return type:
    +

    List[Tensor]

    +
    +
    +
    +
    +
    +class graphnet.models.coarsening.Coarsening(reduce, transfer_attributes)[source]
    +

    Bases: Model

    +

    Base class for coarsening operations.

    +

    Construct Coarsening.

    +
    +
    Parameters:
    +
      +
    • reduce (str) –

    • +
    • transfer_attributes (bool) –

    • +
    +
    +
    +
    +
    +reduce_options = {'avg': (<function avg_pool>, <function avg_pool_x>), 'max': (<function max_pool>, <function max_pool_x>), 'min': (<function min_pool>, <function min_pool_x>), 'sum': (<function sum_pool>, <function sum_pool_x>)}
    +
    +
    +
    +forward(data)[source]
    +

    Perform coarsening operation.

    +
    +
    Return type:
    +

    Union[Data, Batch]

    +
    +
    Parameters:
    +

    data (Data | Batch) –

    +
    +
    +
    +
    +
    +
    +class graphnet.models.coarsening.AttributeCoarsening(attributes, reduce, transfer_attributes)[source]
    +

    Bases: Coarsening

    +

    Coarsen pulses based on specified attributes.

    +

    Construct SimpleCoarsening.

    +
    +
    Parameters:
    +
      +
    • attributes (List[str]) –

    • +
    • reduce (str) –

    • +
    • transfer_attributes (bool) –

    • +
    +
    +
    +
    +
    +
    +class graphnet.models.coarsening.DOMCoarsening(reduce, transfer_attributes, keys)[source]
    +

    Bases: Coarsening

    +

    Coarsen pulses to DOM-level.

    +

    Cluster pulses on the same DOM.

    +
    +
    Parameters:
    +
      +
    • reduce (str) –

    • +
    • transfer_attributes (bool) –

    • +
    • keys (List[str] | None) –

    • +
    +
    +
    +
    +
    +
    +class graphnet.models.coarsening.CustomDOMCoarsening(reduce, transfer_attributes, keys)[source]
    +

    Bases: DOMCoarsening

    +

    Coarsen pulses to DOM-level with additional attributes.

    +

    Cluster pulses on the same DOM.

    +
    +
    Parameters:
    +
      +
    • reduce (str) –

    • +
    • transfer_attributes (bool) –

    • +
    • keys (List[str] | None) –

    • +
    +
    +
    +
    +
    +
    +class graphnet.models.coarsening.DOMAndTimeWindowCoarsening(time_window, reduce, transfer_attributes, keys=['dom_x', 'dom_y', 'dom_z', 'rde', 'pmt_area'], time_key)[source]
    +

    Bases: Coarsening

    +

    Coarsen pulses to DOM-level, with additional time-window clustering.

    +

    Cluster pulses on the same DOM within time_window.

    +
    +
    Parameters:
    +
      +
    • time_window (float) –

    • +
    • reduce (str) –

    • +
    • transfer_attributes (bool) –

    • +
    • keys (List[str]) –

    • +
    • time_key (str) –

    • +
    +
    +
    +
    @@ -497,7 +717,7 @@

    coarseningSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.components.html b/api/graphnet.models.components.html index 831dd56fa..a46e005c6 100644 --- a/api/graphnet.models.components.html +++ b/api/graphnet.models.components.html @@ -460,13 +460,31 @@
    -
    -

    components

    +
    +

    components

    +

    Components for constructing models.

    Submodules

    @@ -518,7 +536,7 @@

    componentsSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.components.layers.html b/api/graphnet.models.components.layers.html index c9568c1c1..06e348db9 100644 --- a/api/graphnet.models.components.layers.html +++ b/api/graphnet.models.components.layers.html @@ -336,11 +336,94 @@ + +

  • @@ -451,7 +534,34 @@
    @@ -461,8 +571,145 @@
    -
    -

    layers

    +
    +

    layers

    +

    Class(es) implementing layers to be used in graphnet models.

    +
    +
    +class graphnet.models.components.layers.DynEdgeConv(nn, aggr, nb_neighbors, features_subset, **kwargs)[source]
    +

    Bases: EdgeConv, LightningModule

    +

    Dynamical edge convolution layer.

    +

    Construct DynEdgeConv.

    +
    +
    Parameters:
    +
      +
    • nn (Callable) – The MLP/torch.Module to be used within the EdgeConv.

    • +
    • aggr (str, default: 'max') – Aggregation method to be used with EdgeConv.

    • +
    • nb_neighbors (int, default: 8) – Number of neighbours to be clustered after the +EdgeConv operation.

    • +
    • features_subset (Union[Sequence[int], slice, None], default: None) – Subset of features in Data.x that should be used +when dynamically performing the new graph clustering after the +EdgeConv operation. Defaults to all features.

    • +
    • **kwargs (Any) – Additional features to be passed to EdgeConv.

    • +
    +
    +
    +
    +
    +forward(x, edge_index, batch)[source]
    +

    Forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • x (Tensor) –

    • +
    • edge_index (Tensor | SparseTensor) –

    • +
    • batch (Tensor | None) –

    • +
    +
    +
    +
    +
    +
    +
    +class graphnet.models.components.layers.EdgeConvTito(nn, aggr, **kwargs)[source]
    +

    Bases: MessagePassing, LightningModule

    +

    Implementation of EdgeConvTito layer used in TITO solution for.

    +

    ‘IceCube - Neutrinos in Deep’ kaggle competition.

    +

    Construct EdgeConvTito.

    +
    +
    Parameters:
    +
      +
    • nn (Callable) – The MLP/torch.Module to be used within the EdgeConvTito.

    • +
    • aggr (str, default: 'max') – Aggregation method to be used with EdgeConvTito.

    • +
    • **kwargs (Any) – Additional features to be passed to EdgeConvTito.

    • +
    +
    +
    +
    +
    +reset_parameters()[source]
    +

    Reset all learnable parameters of the module.

    +
    +
    Return type:
    +

    None

    +
    +
    +
    +
    +
    +forward(x, edge_index)[source]
    +

    Forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • x (Tensor | Tuple[Tensor, Tensor]) –

    • +
    • edge_index (Tensor | SparseTensor) –

    • +
    +
    +
    +
    +
    +
    +message(x_i, x_j)[source]
    +

    Edgeconvtito message passing.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • x_i (Tensor) –

    • +
    • x_j (Tensor) –

    • +
    +
    +
    +
    +
    +
    +
    +class graphnet.models.components.layers.DynTrans(layer_sizes, aggr, features_subset, n_head, **kwargs)[source]
    +

    Bases: EdgeConvTito, LightningModule

    +

    Implementation of dynTrans1 layer used in TITO solution for.

    +

    ‘IceCube - Neutrinos in Deep’ kaggle competition.

    +

    Construct DynTrans.

    +
    +
    Parameters:
    +
      +
    • nn – The MLP/torch.Module to be used within the DynTrans.

    • +
    • layer_sizes (Optional[List[int]], default: None) – List of layer sizes to be used in DynTrans.

    • +
    • aggr (str, default: 'max') – Aggregation method to be used with DynTrans.

    • +
    • features_subset (Union[Sequence[int], slice, None], default: None) – Subset of features in Data.x that should be used +when dynamically performing the new graph clustering after the +EdgeConv operation. Defaults to all features.

    • +
    • n_head (int, default: 8) – Number of heads to be used in the multiheadattention models.

    • +
    • **kwargs (Any) – Additional features to be passed to DynTrans.

    • +
    +
    +
    +
    +
    +forward(x, edge_index, batch)[source]
    +

    Forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • x (Tensor) –

    • +
    • edge_index (Tensor | SparseTensor) –

    • +
    • batch (Tensor | None) –

    • +
    +
    +
    +
    +
    @@ -512,7 +759,7 @@

    layersSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.components.pool.html b/api/graphnet.models.components.pool.html index da3409715..16bcb18ac 100644 --- a/api/graphnet.models.components.pool.html +++ b/api/graphnet.models.components.pool.html @@ -125,6 +125,7 @@ + @@ -343,11 +344,106 @@ + +

  • @@ -451,7 +547,32 @@ @@ -461,8 +582,220 @@
    -
    -

    pool

    +
    +

    pool

    +

    Functions for performing pooling/clustering/coarsening.

    +
    +
    +graphnet.models.components.pool.min_pool(cluster, data, transform)[source]
    +

    Perform min-pooling of Data.

    +

    Like max_pool, just negating `data.x.

    +
    +
    Return type:
    +

    Data

    +
    +
    Parameters:
    +
      +
    • cluster (LongTensor) –

    • +
    • data (Data) –

    • +
    • transform (Any | None) –

    • +
    +
    +
    +
    +
    +
    +graphnet.models.components.pool.min_pool_x(cluster, x, batch, size)[source]
    +

    Perform min-pooling of Tensor.

    +

    Like max_pool_x, just negating `x.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • cluster (LongTensor) –

    • +
    • x (Tensor) –

    • +
    • batch (LongTensor) –

    • +
    • size (int | None) –

    • +
    +
    +
    +
    +
    +
    +graphnet.models.components.pool.sum_pool_and_distribute(tensor, cluster_index, batch)[source]
    +

    Sum-pool values and distribute result to the individual nodes.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • tensor (Tensor) –

    • +
    • cluster_index (LongTensor) –

    • +
    • batch (LongTensor | None) –

    • +
    +
    +
    +
    +
    +
    +graphnet.models.components.pool.group_by(data, keys)[source]
    +

    Group nodes in data that have identical values of keys.

    +

    This grouping is done with in each event in case of batching. This allows +for, e.g., assigning the same index to all pulses on the same PMT or DOM in +the same event. This can be used for coarsening graphs, e.g., from pulse- +level to DOM-level by aggregating feature across each group returned by this +method.

    +
    +
    Return type:
    +

    LongTensor

    +
    +
    Parameters:
    +
      +
    • data (Data | Batch) –

    • +
    • keys (List[str]) –

    • +
    +
    +
    +

    Example

    +
    +
    Given:

    data.f1 = [1,1,2,2,2] +data.f2 = [6,7,7,7,8]

    +
    +
    Calls:

    groupby(data, [‘f1’]) -> [0, 0, 1, 1, 1] +groupby(data, [‘f2’]) -> [0, 1, 1, 1, 2] +groupby(data, [‘f1’, ‘f2’]) -> [0, 1, 2, 2, 3]

    +
    +
    +
    +
    +
    +graphnet.models.components.pool.group_pulses_to_dom(data)[source]
    +

    Group pulses on the same DOM, using DOM and string number.

    +
    +
    Return type:
    +

    Data

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    +
    +graphnet.models.components.pool.group_pulses_to_pmt(data)[source]
    +

    Group pulses on the same PMT, using PMT, DOM, and string number.

    +
    +
    Return type:
    +

    Data

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    +
    +graphnet.models.components.pool.sum_pool_x(cluster, x, batch, size)[source]
    +

    Sum-pool node features according to the clustering defined in cluster.

    +
    +
    Parameters:
    +
      +
    • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

    • +
    • x (Tensor) – Node feature matrix +\(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

    • +
    • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, +B-1\}}^N\), which assigns each node to a specific example.

    • +
    • size (Optional[int], default: None) – The maximum number of clusters in a single +example. This property is useful to obtain a batch-wise dense +representation, e.g. for applying FC layers, but should only be +used if the size of the maximum number of clusters per example is +known in advance.

    • +
    +
    +
    Return type:
    +

    Tensor

    +
    +
    +
    +
    +
    +graphnet.models.components.pool.std_pool_x(cluster, x, batch, size)[source]
    +

    Std-pool node features according to the clustering defined in cluster.

    +
    +
    Parameters:
    +
      +
    • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

    • +
    • x (Tensor) – Node feature matrix +\(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

    • +
    • batch (LongTensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, +B-1\}}^N\), which assigns each node to a specific example.

    • +
    • size (Optional[int], default: None) – The maximum number of clusters in a single +example. This property is useful to obtain a batch-wise dense +representation, e.g. for applying FC layers, but should only be +used if the size of the maximum number of clusters per example is +known in advance.

    • +
    +
    +
    Return type:
    +

    Tensor

    +
    +
    +
    +
    +
    +graphnet.models.components.pool.sum_pool(cluster, data, transform)[source]
    +

    Pool and coarsen graph according to the clustering defined in cluster.

    +

    All nodes within the same cluster will be represented as one node. +Final node features are defined by the sum of features of all nodes +within the same cluster, node positions are averaged and edge indices are +defined to be the union of the edge indices of all nodes within the same +cluster.

    +
    +
    Parameters:
    +
      +
    • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

    • +
    • data (Data) – Graph data object.

    • +
    • transform (Optional[Callable], default: None) – A function/transform that takes in the +coarsened and pooled torch_geometric.data.Data object and +returns a transformed version.

    • +
    +
    +
    Return type:
    +

    Data

    +
    +
    +
    +
    +
    +graphnet.models.components.pool.std_pool(cluster, data, transform)[source]
    +

    Pool and coarsen graph according to the clustering defined in cluster.

    +

    All nodes within the same cluster will be represented as one node. +Final node features are defined by the std of features of all nodes +within the same cluster, node positions are averaged and edge indices are +defined to be the union of the edge indices of all nodes within the same +cluster.

    +
    +
    Parameters:
    +
      +
    • cluster (LongTensor) – Cluster vector \(\mathbf{c} \in \{ 0, +\ldots, N - 1 \}^N\), which assigns each node to a specific cluster.

    • +
    • data (Data) – Graph data object.

    • +
    • transform (Optional[Callable], default: None) – A function/transform that takes in the +coarsened and pooled torch_geometric.data.Data object and +returns a transformed version.

    • +
    +
    +
    Return type:
    +

    Data

    +
    +
    +
    @@ -512,7 +845,7 @@

    poolSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.detector.detector.html b/api/graphnet.models.detector.detector.html index b8a0e72ee..773954d8d 100644 --- a/api/graphnet.models.detector.detector.html +++ b/api/graphnet.models.detector.detector.html @@ -343,11 +343,45 @@ + +
  • @@ -458,7 +492,20 @@
    @@ -468,8 +515,44 @@
    -
    -

    detector

    +
    +

    detector

    +

    Base detector-specific Model class(es).

    +
    +
    +class graphnet.models.detector.detector.Detector[source]
    +

    Bases: Model

    +

    Base class for all detector-specific read-ins in graphnet.

    +

    Construct Detector.

    +
    +
    +
    +
    +abstract feature_map()[source]
    +

    List of features used/assumed by inheriting Detector objects.

    +
    +
    Return type:
    +

    Dict[str, Callable]

    +
    +
    +
    +
    +
    +forward(node_features, node_feature_names)[source]
    +

    Pre-process graph Data features and build graph adjacency.

    +
    +
    Return type:
    +

    Data

    +
    +
    Parameters:
    +
      +
    • node_features (tensor) –

    • +
    • node_feature_names (List[str]) –

    • +
    +
    +
    +
    +
    @@ -519,7 +602,7 @@

    detectorSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.detector.html b/api/graphnet.models.detector.html index 0c7904b36..d42b5269f 100644 --- a/api/graphnet.models.detector.html +++ b/api/graphnet.models.detector.html @@ -467,14 +467,27 @@
    -
    -

    detector

    +
    +

    detector

    +

    Detector-specific modules, for data ingestion and standardisation.

    Submodules

    @@ -526,7 +539,7 @@

    detector Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.detector.icecube.html b/api/graphnet.models.detector.icecube.html index 003c24f60..e37d2e2da 100644 --- a/api/graphnet.models.detector.icecube.html +++ b/api/graphnet.models.detector.icecube.html @@ -350,11 +350,96 @@ + +

  • @@ -458,7 +543,36 @@
    @@ -468,8 +582,85 @@
    -
    -

    icecube

    +
    +

    icecube

    +

    IceCube-specific Detector class(es).

    +
    +
    +class graphnet.models.detector.icecube.IceCube86[source]
    +

    Bases: Detector

    +

    Detector class for IceCube-86.

    +

    Construct Detector.

    +
    +
    +
    +
    +feature_map()[source]
    +

    Map standardization functions to each dimension of input data.

    +
    +
    Return type:
    +

    Dict[str, Callable]

    +
    +
    +
    +
    +
    +
    +class graphnet.models.detector.icecube.IceCubeKaggle[source]
    +

    Bases: Detector

    +

    Detector class for Kaggle Competition.

    +

    Construct Detector.

    +
    +
    +
    +
    +feature_map()[source]
    +

    Map standardization functions to each dimension of input data.

    +
    +
    Return type:
    +

    Dict[str, Callable]

    +
    +
    +
    +
    +
    +
    +class graphnet.models.detector.icecube.IceCubeDeepCore[source]
    +

    Bases: Detector

    +

    Detector class for IceCube-DeepCore.

    +

    Construct Detector.

    +
    +
    +
    +
    +feature_map()[source]
    +

    Map standardization functions to each dimension of input data.

    +
    +
    Return type:
    +

    Dict[str, Callable]

    +
    +
    +
    +
    +
    +
    +class graphnet.models.detector.icecube.IceCubeUpgrade[source]
    +

    Bases: Detector

    +

    Detector class for IceCube-Upgrade.

    +

    Construct Detector.

    +
    +
    +
    +
    +feature_map()[source]
    +

    Map standardization functions to each dimension of input data.

    +
    +
    Return type:
    +

    Dict[str, Callable]

    +
    +
    +
    +
    @@ -519,7 +710,7 @@

    icecubeSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.detector.prometheus.html b/api/graphnet.models.detector.prometheus.html index 9b27ebab9..0c266b36a 100644 --- a/api/graphnet.models.detector.prometheus.html +++ b/api/graphnet.models.detector.prometheus.html @@ -357,11 +357,36 @@ + +

  • @@ -458,7 +483,18 @@
    @@ -468,8 +504,28 @@
    -
    -

    prometheus

    +
    +

    prometheus

    +

    Prometheus-specific Detector class(es).

    +
    +
    +class graphnet.models.detector.prometheus.Prometheus[source]
    +

    Bases: Detector

    +

    Detector class for Prometheus prototype.

    +

    Construct Detector.

    +
    +
    +
    +
    +feature_map()[source]
    +

    Map standardization functions to each dimension.

    +
    +
    Return type:
    +

    Dict[str, Callable]

    +
    +
    +
    +
    @@ -519,7 +575,7 @@

    prometheusSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.gnn.convnet.html b/api/graphnet.models.gnn.convnet.html index 1d2d39cd2..52d6ba1ea 100644 --- a/api/graphnet.models.gnn.convnet.html +++ b/api/graphnet.models.gnn.convnet.html @@ -350,11 +350,36 @@ + +
  • @@ -472,7 +497,18 @@
    @@ -482,8 +518,42 @@
    -
    -

    convnet

    +
    +

    convnet

    +

    Implementation of the ConvNet GNN model architecture.

    +

    Author: Martin Ha Minh

    +
    +
    +class graphnet.models.gnn.convnet.ConvNet(nb_inputs, nb_outputs, nb_intermediate, dropout_ratio)[source]
    +

    Bases: GNN

    +

    ConvNet (convolutional network) model.

    +

    Construct ConvNet.

    +
    +
    Parameters:
    +
      +
    • nb_inputs (int) – Number of input features, i.e. dimension of input +layer.

    • +
    • nb_outputs (int) – Number of prediction labels, i.e. dimension of +output layer.

    • +
    • nb_intermediate (int, default: 128) – Number of nodes in intermediate layer(s).

    • +
    • dropout_ratio (float, default: 0.3) – Fraction of nodes to drop.

    • +
    +
    +
    +
    +
    +forward(data)[source]
    +

    Apply learnable forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    @@ -533,7 +603,7 @@

    convnet Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.gnn.dynedge.html b/api/graphnet.models.gnn.dynedge.html index 1ae3f0c44..908284418 100644 --- a/api/graphnet.models.gnn.dynedge.html +++ b/api/graphnet.models.gnn.dynedge.html @@ -357,11 +357,36 @@ + +

  • @@ -472,7 +497,18 @@
    @@ -482,8 +518,64 @@
    -
    -

    dynedge

    +
    +

    dynedge

    +

    Implementation of the DynEdge GNN model architecture.

    +
    +
    +class graphnet.models.gnn.dynedge.DynEdge(nb_inputs, *, nb_neighbours, features_subset, dynedge_layer_sizes, post_processing_layer_sizes, readout_layer_sizes, global_pooling_schemes, add_global_variables_after_pooling)[source]
    +

    Bases: GNN

    +

    DynEdge (dynamical edge convolutional) model.

    +

    Construct DynEdge.

    +
    +
    Parameters:
    +
      +
    • nb_inputs (int) – Number of input features on each node.

    • +
    • nb_neighbours (int, default: 8) – Number of neighbours to used in the k-nearest +neighbour clustering which is performed after each (dynamical) +edge convolution.

    • +
    • features_subset (Union[List[int], slice, None], default: None) – The subset of latent features on each node that +are used as metric dimensions when performing the k-nearest +neighbours clustering. Defaults to [0,1,2].

    • +
    • dynedge_layer_sizes (Optional[List[Tuple[int, ...]]], default: None) – The layer sizes, or latent feature dimenions, +used in the DynEdgeConv layer. Each entry in +dynedge_layer_sizes corresponds to a single DynEdgeConv +layer; the integers in the corresponding tuple corresponds to +the layer sizes in the multi-layer perceptron (MLP) that is +applied within each DynEdgeConv layer. That is, a list of +size-two tuples means that all DynEdgeConv layers contain a +two-layer MLP. +Defaults to [(128, 256), (336, 256), (336, 256), (336, 256)].

    • +
    • post_processing_layer_sizes (Optional[List[int]], default: None) – Hidden layer sizes in the MLP +following the skip-concatenation of the outputs of each +DynEdgeConv layer. Defaults to [336, 256].

    • +
    • readout_layer_sizes (Optional[List[int]], default: None) – Hidden layer sizes in the MLP following the +post-processing _and_ optional global pooling. As this is the +last layer(s) in the model, the last layer in the read-out +yields the output of the DynEdge model. Defaults to [128,].

    • +
    • global_pooling_schemes (Union[str, List[str], None], default: None) – The list global pooling schemes to use. +Options are: “min”, “max”, “mean”, and “sum”.

    • +
    • add_global_variables_after_pooling (bool, default: False) – Whether to add global variables +after global pooling. The alternative is to added (distribute) +them to the individual nodes before any convolutional +operations.

    • +
    +
    +
    +
    +
    +forward(data)[source]
    +

    Apply learnable forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    @@ -533,7 +625,7 @@

    dynedge Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.gnn.dynedge_jinst.html b/api/graphnet.models.gnn.dynedge_jinst.html index f46df1144..407d0cb08 100644 --- a/api/graphnet.models.gnn.dynedge_jinst.html +++ b/api/graphnet.models.gnn.dynedge_jinst.html @@ -364,11 +364,36 @@ + +

  • @@ -472,7 +497,18 @@
    @@ -482,8 +518,39 @@
    -
    -

    dynedge_jinst

    +
    +

    dynedge_jinst

    +

    Implementation of the exact DynEdge architecture used in [2209.03042].

    +

    Author: Rasmus Oersoe

    +
    +
    +class graphnet.models.gnn.dynedge_jinst.DynEdgeJINST(nb_inputs, layer_size_scale)[source]
    +

    Bases: GNN

    +

    DynEdge (dynamical edge convolutional) model used in [2209.03042].

    +

    Construct DynEdgeJINST.

    +
    +
    Parameters:
    +
      +
    • nb_inputs (int) – Number of input features.

    • +
    • nb_outputs – Number of output features.

    • +
    • layer_size_scale (int, default: 4) – Integer that scales the size of hidden layers.

    • +
    +
    +
    +
    +
    +forward(data)[source]
    +

    Apply learnable forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    @@ -533,7 +600,7 @@

    dynedge_jinst Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.gnn.dynedge_kaggle_tito.html b/api/graphnet.models.gnn.dynedge_kaggle_tito.html index 6f259d9bc..a27928192 100644 --- a/api/graphnet.models.gnn.dynedge_kaggle_tito.html +++ b/api/graphnet.models.gnn.dynedge_kaggle_tito.html @@ -371,11 +371,36 @@ + +

  • @@ -472,7 +497,18 @@
    @@ -482,8 +518,49 @@
    -
    -

    dynedge_kaggle_tito

    +
    +

    dynedge_kaggle_tito

    +

    Implementation of DynEdge architecture used in.

    +
    +

    IceCube - Neutrinos in Deep Ice

    +
    +

    Reconstruct the direction of neutrinos from the Universe to the South Pole

    +

    Kaggle competition.

    +

    Solution by TITO.

    +
    +
    +class graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO(nb_inputs, features_subset(0, 4, None), dyntrans_layer_sizes, global_pooling_schemes=['max'])[source]
    +

    Bases: GNN

    +

    DynEdge (dynamical edge convolutional) model.

    +

    Construct DynEdge.

    +
    +
    Parameters:
    +
      +
    • nb_inputs (int) – Number of input features on each node.

    • +
    • features_subset (slice, default: slice(0, 4, None)) – The subset of latent features on each node that +are used as metric dimensions when performing the k-nearest +neighbours clustering. Defaults to [0,1,2,3].

    • +
    • dyntrans_layer_sizes (Optional[List[Tuple[int, ...]]], default: None) – The layer sizes, or latent feature dimenions, +used in the DynTrans layer.

    • +
    • global_pooling_schemes (List[str], default: ['max']) – The list global pooling schemes to use. +Options are: “min”, “max”, “mean”, and “sum”.

    • +
    +
    +
    +
    +
    +forward(data)[source]
    +

    Apply learnable forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    @@ -533,7 +610,7 @@

    dynedge_kaggle_t

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.gnn.gnn.html b/api/graphnet.models.gnn.gnn.html index a59853501..78ee2338d 100644 --- a/api/graphnet.models.gnn.gnn.html +++ b/api/graphnet.models.gnn.gnn.html @@ -378,11 +378,54 @@ + +
  • @@ -472,7 +515,22 @@
    @@ -482,8 +540,47 @@
    -
    -

    gnn

    +
    +

    gnn

    +

    Base GNN-specific Model class(es).

    +
    +
    +class graphnet.models.gnn.gnn.GNN(nb_inputs, nb_outputs)[source]
    +

    Bases: Model

    +

    Base class for all core GNN models in graphnet.

    +

    Construct GNN.

    +
    +
    Parameters:
    +
      +
    • nb_inputs (int) –

    • +
    • nb_outputs (int) –

    • +
    +
    +
    +
    +
    +property nb_inputs: int
    +

    Return number of input features.

    +
    +
    +
    +property nb_outputs: int
    +

    Return number of output features.

    +
    +
    +
    +abstract forward(data)[source]
    +

    Apply learnable forward pass in model.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    @@ -533,7 +630,7 @@

    gnnSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.gnn.html b/api/graphnet.models.gnn.html index a98fe7164..be8f47a7f 100644 --- a/api/graphnet.models.gnn.html +++ b/api/graphnet.models.gnn.html @@ -481,16 +481,32 @@
    -
    -

    gnn

    +
    +

    gnn

    +

    GNN-specific modules, for performing the main learnable operations.

    Submodules

    @@ -542,7 +558,7 @@

    gnnSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.graphs.edges.edges.html b/api/graphnet.models.graphs.edges.edges.html index ea7739e62..98b04f64a 100644 --- a/api/graphnet.models.graphs.edges.edges.html +++ b/api/graphnet.models.graphs.edges.edges.html @@ -363,11 +363,63 @@ + + @@ -473,7 +525,24 @@
    @@ -483,8 +552,102 @@
    -
    -

    edges

    +
    +

    edges

    +

    Class(es) for building/connecting graphs.

    +
    +
    +class graphnet.models.graphs.edges.edges.EdgeDefinition(name, class_name, level, log_folder, **kwargs)[source]
    +

    Bases: Model

    +

    Base class for graph building.

    +

    Construct Logger.

    +
    +
    Parameters:
    +
      +
    • name (str | None) –

    • +
    • class_name (str | None) –

    • +
    • level (int) –

    • +
    • log_folder (str | None) –

    • +
    • kwargs (Any) –

    • +
    +
    +
    +
    +
    +forward(graph)[source]
    +

    Construct edges based on problem specific implementation of.

    +

    ´_construct_edges´

    +
    +
    Parameters:
    +

    graph (Data) – a graph without edges

    +
    +
    Returns:
    +

    a graph with edges

    +
    +
    Return type:
    +

    graph

    +
    +
    +
    +
    +
    +
    +class graphnet.models.graphs.edges.edges.KNNEdges(nb_nearest_neighbours, columns=[0, 1, 2])[source]
    +

    Bases: EdgeDefinition

    +

    Builds edges from the k-nearest neighbours.

    +

    K-NN Edge definition.

    +

    Will connect nodes together with their ´nb_nearest_neighbours´ +nearest neighbours in the feature space given by ´columns´.

    +
    +
    Parameters:
    +
      +
    • nb_nearest_neighbours (int) – number of neighbours.

    • +
    • columns (List[int], default: [0, 1, 2]) – Node features to use for distance calculation.

    • +
    • [0 (Defaults to) –

    • +
    • 1

    • +
    • 2].

    • +
    +
    +
    +
    +
    +
    +class graphnet.models.graphs.edges.edges.RadialEdges(radius, columns=[0, 1, 2])[source]
    +

    Bases: EdgeDefinition

    +

    Builds graph from a sphere of chosen radius centred at each node.

    +

    Radial edges.

    +

    Connects each node to other nodes that are within a sphere of +radius ´r´ centered at the node. The feature space of ´r´ is defined +by ´columns´

    +
    +
    Parameters:
    +
      +
    • radius (float) – radius of sphere

    • +
    • columns (List[int], default: [0, 1, 2]) – columns of the node feature matrix used.

    • +
    • [0 (Defaults to) –

    • +
    • 1

    • +
    • 2].

    • +
    +
    +
    +
    +
    +
    +class graphnet.models.graphs.edges.edges.EuclideanEdges(sigma, threshold, columns)[source]
    +

    Bases: EdgeDefinition

    +

    Builds edges according to Euclidean distance between nodes.

    +

    See https://arxiv.org/pdf/1809.06166.pdf.

    +

    Construct EuclideanEdges.

    +
    +
    Parameters:
    +
      +
    • sigma (float) –

    • +
    • threshold (float) –

    • +
    • columns (List[int]) –

    • +
    +
    +
    +
    @@ -534,7 +697,7 @@

    edgesSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.graphs.edges.html b/api/graphnet.models.graphs.edges.html index e6d65d3e9..e4feb5f18 100644 --- a/api/graphnet.models.graphs.edges.html +++ b/api/graphnet.models.graphs.edges.html @@ -482,12 +482,22 @@
    -
    -

    edges

    +
    +

    edges

    +

    Modules for constructing graphs.

    +

    ´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features.

    Submodules

    @@ -539,7 +549,7 @@

    edges Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.graphs.graph_definition.html b/api/graphnet.models.graphs.graph_definition.html index 45116549f..8dcade719 100644 --- a/api/graphnet.models.graphs.graph_definition.html +++ b/api/graphnet.models.graphs.graph_definition.html @@ -371,11 +371,36 @@ + +
  • @@ -465,7 +490,18 @@
    @@ -475,8 +511,59 @@
    -
    -

    graph_definition

    +
    +

    graph_definition

    +

    Modules for defining graphs.

    +

    These are self-contained graph definitions that hold all the graph-altering +code in graphnet. These modules define what the GNNs sees as input and can be +passed to dataloaders during training and deployment.

    +
    +
    +class graphnet.models.graphs.graph_definition.GraphDefinition(detector, node_definition, edge_definition, node_feature_names, dtype)[source]
    +

    Bases: Model

    +

    An Abstract class to create graph definitions from.

    +

    Construct ´GraphDefinition´. The ´detector´ holds.

    +

    ´Detector´-specific code. E.g. scaling/standardization and geometry +tables.

    +

    ´node_definition´ defines the nodes in the graph.

    +

    ´edge_definition´ defines the connectivity of the nodes in the graph.

    +
    +
    Parameters:
    +
      +
    • detector (Detector) – The corresponding ´Detector´ representing the data.

    • +
    • node_definition (NodeDefinition) – Definition of nodes.

    • +
    • edge_definition (Optional[EdgeDefinition], default: None) – Definition of edges. Defaults to None.

    • +
    • node_feature_names (Optional[List[str]], default: None) – Names of node feature columns. Defaults to None

    • +
    • dtype (Optional[dtype], default: torch.float32) – data type used for node features. e.g. ´torch.float´

    • +
    +
    +
    +
    +
    +forward(node_features, node_feature_names, truth_dicts, custom_label_functions, loss_weight_column, loss_weight, loss_weight_default_value, data_path)[source]
    +

    Construct graph as ´Data´ object.

    +
    +
    Parameters:
    +
      +
    • node_features (ndarray) – node features for graph. Shape ´[num_nodes, d]´

    • +
    • node_feature_names (List[str]) – name of each column. Shape ´[,d]´.

    • +
    • truth_dicts (Optional[List[Dict[str, Any]]], default: None) – Dictionary containing truth labels.

    • +
    • custom_label_functions (Optional[Dict[str, Callable[..., Any]]], default: None) – Custom label functions. See https://github.com/graphnet-team/graphnet/blob/main/GETTING_STARTED.md#adding-custom-truth-labels.

    • +
    • loss_weight_column (Optional[str], default: None) – Name of column that holds loss weight. Defaults to None.

    • +
    • loss_weight (Optional[float], default: None) – Loss weight associated with event. Defaults to None.

    • +
    • loss_weight_default_value (Optional[float], default: None) – default value for loss weight. Used in instances where some events have no pre-defined loss weight. Defaults to None.

    • +
    • data_path (Optional[str], default: None) – Path to dataset data files. Defaults to None.

    • +
    +
    +
    Return type:
    +

    Data

    +
    +
    Returns:
    +

    graph

    +
    +
    +
    +
    @@ -526,7 +613,7 @@

    graph_definition

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.graphs.graphs.html b/api/graphnet.models.graphs.graphs.html index ca229295b..0d889a033 100644 --- a/api/graphnet.models.graphs.graphs.html +++ b/api/graphnet.models.graphs.graphs.html @@ -378,11 +378,25 @@ + +
  • @@ -465,7 +479,14 @@
    @@ -475,8 +496,31 @@
    -
    -

    graphs

    +
    +

    graphs

    +

    A module containing different graph representations in GraphNeT.

    +
    +
    +class graphnet.models.graphs.graphs.KNNGraph(detector, node_definition, node_feature_names, dtype, nb_nearest_neighbours, columns=[0, 1, 2])[source]
    +

    Bases: GraphDefinition

    +

    A Graph representation where Edges are drawn to nearest neighbours.

    +

    Construct k-nn graph representation.

    +
    +
    Parameters:
    +
      +
    • detector (Detector) – Detector that represents your data.

    • +
    • node_definition (NodeDefinition) – Definition of nodes in the graph.

    • +
    • node_feature_names (Optional[List[str]], default: None) – Name of node features.

    • +
    • dtype (Optional[dtype], default: torch.float32) – data type for node features.

    • +
    • nb_nearest_neighbours (int, default: 8) – Number of edges for each node. Defaults to 8.

    • +
    • columns (List[int], default: [0, 1, 2]) – node feature columns used for distance calculation

    • +
    • [0 (. Defaults to) –

    • +
    • 1

    • +
    • 2].

    • +
    +
    +
    +
    @@ -526,7 +570,7 @@

    graphsSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.graphs.html b/api/graphnet.models.graphs.html index fef0be82e..fbbc2ff3e 100644 --- a/api/graphnet.models.graphs.html +++ b/api/graphnet.models.graphs.html @@ -474,8 +474,12 @@
    -
    -

    graphs

    +
    +

    graphs

    +

    Modules for constructing graphs.

    +

    ´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features.

    Subpackages

    @@ -545,7 +555,7 @@

    graphs Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.graphs.nodes.html b/api/graphnet.models.graphs.nodes.html index 1609b8099..2e99388bb 100644 --- a/api/graphnet.models.graphs.nodes.html +++ b/api/graphnet.models.graphs.nodes.html @@ -482,12 +482,20 @@
    -
    -

    nodes

    +
    +

    nodes

    +

    Modules for constructing graphs.

    +

    ´GraphDefinition´ defines the nodes and their features, and contains general +graph-manipulation.´EdgeDefinition´ defines how edges are drawn between nodes +and their features.

    Submodules

    @@ -539,7 +547,7 @@

    nodes Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.graphs.nodes.nodes.html b/api/graphnet.models.graphs.nodes.nodes.html index 23b5575bd..a5ac7ba2d 100644 --- a/api/graphnet.models.graphs.nodes.nodes.html +++ b/api/graphnet.models.graphs.nodes.nodes.html @@ -370,11 +370,63 @@ + + @@ -473,7 +525,24 @@
    @@ -483,8 +552,65 @@
    -
    -

    nodes

    +
    +

    nodes

    +

    Class(es) for building/connecting graphs.

    +
    +
    +class graphnet.models.graphs.nodes.nodes.NodeDefinition[source]
    +

    Bases: Model

    +

    Base class for graph building.

    +

    Construct Detector.

    +
    +
    +
    +
    +forward(x)[source]
    +

    Construct nodes from raw node features.

    +
    +
    Parameters:
    +
      +
    • x (tensor) – standardized node features with shape ´[num_pulses, d]´,

    • +
    • features. (where ´d´ is the number of node) –

    • +
    +
    +
    Returns:
    +

    a graph without edges

    +
    +
    Return type:
    +

    graph

    +
    +
    +
    +
    +
    +property nb_outputs: int
    +

    Return number of output features.

    +

    This the default, but may be overridden by specific inheriting classes.

    +
    +
    +
    +set_number_of_inputs(node_feature_names)[source]
    +

    Return number of inputs expected by node definition.

    +
    +
    Parameters:
    +

    node_feature_names (List[str]) – name of each node feature column.

    +
    +
    Return type:
    +

    None

    +
    +
    +
    +
    +
    +
    +class graphnet.models.graphs.nodes.nodes.NodesAsPulses[source]
    +

    Bases: NodeDefinition

    +

    Represent each measured pulse of Cherenkov Radiation as a node.

    +

    Construct Detector.

    +
    +
    +
    @@ -534,7 +660,7 @@

    nodesSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.html b/api/graphnet.models.html index fa38a95a6..d514e1468 100644 --- a/api/graphnet.models.html +++ b/api/graphnet.models.html @@ -445,8 +445,14 @@
    -
    -

    models

    +
    +

    models

    +

    Modules for configuring and building models.

    +

    graphnet.models allows for configuring and building complex GNN models using +simple, physics-oriented components. This module provides modular components +subclassing torch.nn.Module, meaning that users only need to import a few, +existing, purpose-built components and chain them together to form a complete +GNN

    Subpackages

    @@ -542,7 +567,7 @@

    modelsSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.model.html b/api/graphnet.models.model.html index f1bc235f8..286f23f20 100644 --- a/api/graphnet.models.model.html +++ b/api/graphnet.models.model.html @@ -372,11 +372,108 @@ + +
  • @@ -436,7 +533,34 @@
    @@ -446,8 +570,183 @@
    -
    -

    model

    +
    +

    model

    +

    Base class(es) for building models.

    +
    +
    +class graphnet.models.model.Model(name, class_name, level, log_folder, **kwargs)[source]
    +

    Bases: Logger, Configurable, LightningModule, ABC

    +

    Base class for all models in graphnet.

    +

    Construct Logger.

    +
    +
    Parameters:
    +
      +
    • name (str | None) –

    • +
    • class_name (str | None) –

    • +
    • level (int) –

    • +
    • log_folder (str | None) –

    • +
    • kwargs (Any) –

    • +
    +
    +
    +
    +
    +abstract forward(x)[source]
    +

    Forward pass.

    +
    +
    Return type:
    +

    Union[Tensor, Data]

    +
    +
    Parameters:
    +

    x (Tensor | Data) –

    +
    +
    +
    +
    +
    +fit(train_dataloader, val_dataloader, *, max_epochs, gpus, callbacks, ckpt_path, logger, log_every_n_steps, gradient_clip_val, distribution_strategy, **trainer_kwargs)[source]
    +

    Fit Model using pytorch_lightning.Trainer.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +
      +
    • train_dataloader (DataLoader) –

    • +
    • val_dataloader (DataLoader | None) –

    • +
    • max_epochs (int) –

    • +
    • gpus (List[int] | int | None) –

    • +
    • callbacks (List[Callback] | None) –

    • +
    • ckpt_path (str | None) –

    • +
    • logger (Logger | None) –

    • +
    • log_every_n_steps (int) –

    • +
    • gradient_clip_val (float | None) –

    • +
    • distribution_strategy (str | None) –

    • +
    • trainer_kwargs (Any) –

    • +
    +
    +
    +
    +
    +
    +predict(dataloader, gpus, distribution_strategy)[source]
    +

    Return predictions for dataloader.

    +

    Returns a list of Tensors, one for each model output.

    +
    +
    Return type:
    +

    List[Tensor]

    +
    +
    Parameters:
    +
      +
    • dataloader (DataLoader) –

    • +
    • gpus (List[int] | int | None) –

    • +
    • distribution_strategy (str | None) –

    • +
    +
    +
    +
    +
    +
    +predict_as_dataframe(dataloader, prediction_columns, *, additional_attributes, gpus, distribution_strategy)[source]
    +

    Return predictions for dataloader as a DataFrame.

    +

    Include additional_attributes as additional columns in the output +DataFrame.

    +
    +
    Return type:
    +

    DataFrame

    +
    +
    Parameters:
    +
      +
    • dataloader (DataLoader) –

    • +
    • prediction_columns (List[str]) –

    • +
    • additional_attributes (List[str] | None) –

    • +
    • gpus (List[int] | int | None) –

    • +
    • distribution_strategy (str | None) –

    • +
    +
    +
    +
    +
    +
    +save(path)[source]
    +

    Save entire model to path.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +

    path (str) –

    +
    +
    +
    +
    +
    +classmethod load(path)[source]
    +

    Load entire model from path.

    +
    +
    Return type:
    +

    Model

    +
    +
    Parameters:
    +

    path (str) –

    +
    +
    +
    +
    +
    +save_state_dict(path)[source]
    +

    Save model state_dict to path.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +

    path (str) –

    +
    +
    +
    +
    +
    +load_state_dict(path, **kargs)[source]
    +

    Load model state_dict from path.

    +
    +
    Return type:
    +

    Model

    +
    +
    Parameters:
    +
      +
    • path (str | Dict) –

    • +
    • kargs (Any | None) –

    • +
    +
    +
    +
    +
    +
    +classmethod from_config(source, trust, load_modules)[source]
    +

    Construct Model instance from source configuration.

    +
    +
    Parameters:
    +
      +
    • trust (bool, default: False) – Whether to trust the ModelConfig file enough to eval(…) +any lambda function expressions contained.

    • +
    • load_modules (Optional[List[str]], default: None) – List of modules used in the definition of the model +which, as a consequence, need to be loaded into the global +namespace. Defaults to loading torch.

    • +
    • source (ModelConfig | str) –

    • +
    +
    +
    Raises:
    +

    ValueError – If the ModelConfig contains lambda functions but + trust = False.

    +
    +
    Return type:
    +

    Model

    +
    +
    +
    +
    @@ -497,7 +796,7 @@

    modelSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.standard_model.html b/api/graphnet.models.standard_model.html index da580b529..b886a0725 100644 --- a/api/graphnet.models.standard_model.html +++ b/api/graphnet.models.standard_model.html @@ -379,11 +379,135 @@ + +

  • @@ -436,7 +560,40 @@ @@ -446,8 +603,193 @@
    -
    -

    standard_model

    +
    +

    standard_model

    +

    Standard model class(es).

    +
    +
    +class graphnet.models.standard_model.StandardModel(*, graph_definition, gnn, tasks, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs, scheduler_class, scheduler_kwargs, scheduler_config)[source]
    +

    Bases: Model

    +

    Main class for standard models in graphnet.

    +

    This class chains together the different elements of a complete GNN-based +model (detector read-in, GNN architecture, and task-specific read-outs).

    +

    Construct StandardModel.

    +
    +
    Parameters:
    +
      +
    • graph_definition (GraphDefinition) –

    • +
    • gnn (GNN) –

    • +
    • tasks (Task | List[Task]) –

    • +
    • optimizer_class (type) –

    • +
    • optimizer_kwargs (Dict | None) –

    • +
    • scheduler_class (type | None) –

    • +
    • scheduler_kwargs (Dict | None) –

    • +
    • scheduler_config (Dict | None) –

    • +
    +
    +
    +
    +
    +property target_labels: List[str]
    +

    Return target label.

    +
    +
    +
    +property prediction_labels: List[str]
    +

    Return prediction labels.

    +
    +
    +
    +configure_optimizers()[source]
    +

    Configure the model’s optimizer(s).

    +
    +
    Return type:
    +

    Dict[str, Any]

    +
    +
    +
    +
    +
    +forward(data)[source]
    +

    Forward pass, chaining model components.

    +
    +
    Return type:
    +

    List[Union[Tensor, Data]]

    +
    +
    Parameters:
    +

    data (Data) –

    +
    +
    +
    +
    +
    +shared_step(batch, batch_idx)[source]
    +

    Perform shared step.

    +

    Applies the forward pass and the following loss calculation, shared +between the training and validation step.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • batch (Data) –

    • +
    • batch_idx (int) –

    • +
    +
    +
    +
    +
    +
    +training_step(train_batch, batch_idx)[source]
    +

    Perform training step.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • train_batch (Data) –

    • +
    • batch_idx (int) –

    • +
    +
    +
    +
    +
    +
    +validation_step(val_batch, batch_idx)[source]
    +

    Perform validation step.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • val_batch (Data) –

    • +
    • batch_idx (int) –

    • +
    +
    +
    +
    +
    +
    +compute_loss(preds, data, verbose)[source]
    +

    Compute and sum losses across tasks.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • preds (Tensor) –

    • +
    • data (Data) –

    • +
    • verbose (bool) –

    • +
    +
    +
    +
    +
    +
    +inference()[source]
    +

    Activate inference mode.

    +
    +
    Return type:
    +

    None

    +
    +
    +
    +
    +
    +train(mode)[source]
    +

    Deactivate inference mode.

    +
    +
    Return type:
    +

    Model

    +
    +
    Parameters:
    +

    mode (bool) –

    +
    +
    +
    +
    +
    +predict(dataloader, gpus, distribution_strategy)[source]
    +

    Return predictions for dataloader.

    +
    +
    Return type:
    +

    List[Tensor]

    +
    +
    Parameters:
    +
      +
    • dataloader (DataLoader) –

    • +
    • gpus (List[int] | int | None) –

    • +
    • distribution_strategy (str | None) –

    • +
    +
    +
    +
    +
    +
    +predict_as_dataframe(dataloader, prediction_columns, *, additional_attributes, gpus, distribution_strategy)[source]
    +

    Return predictions for dataloader as a DataFrame.

    +

    Include additional_attributes as additional columns in the output +DataFrame.

    +
    +
    Return type:
    +

    DataFrame

    +
    +
    Parameters:
    +
      +
    • dataloader (DataLoader) –

    • +
    • prediction_columns (List[str] | None) –

    • +
    • additional_attributes (List[str] | None) –

    • +
    • gpus (List[int] | int | None) –

    • +
    • distribution_strategy (str | None) –

    • +
    +
    +
    +
    +
    @@ -497,7 +839,7 @@

    standard_modelSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.task.classification.html b/api/graphnet.models.task.classification.html index 3fb005f08..50cf9ca15 100644 --- a/api/graphnet.models.task.classification.html +++ b/api/graphnet.models.task.classification.html @@ -364,11 +364,101 @@ + +

  • @@ -458,7 +548,34 @@ @@ -468,8 +585,147 @@
    -
    -

    classification

    +
    +

    classification

    +

    Classification-specific Model class(es).

    +
    +
    +class graphnet.models.task.classification.MulticlassClassificationTask(nb_outputs, target_labels, *args, **kwargs)[source]
    +

    Bases: IdentityTask

    +

    General task for classifying any number of classes.

    +

    Requires the same number of input features as the number of classes being +predicted. Returns the untransformed latent features, which are interpreted +as the logits for each class being classified.

    +

    Construct IdentityTask.

    +

    Return the nb_outputs as a direct, affine transformation of the last +hidden layer.

    +
    +
    Parameters:
    +
      +
    • nb_outputs (int) –

    • +
    • target_labels (List[str] | Any) –

    • +
    • args (Any) –

    • +
    • kwargs (Any) –

    • +
    +
    +
    +
    +
    +
    +class graphnet.models.task.classification.BinaryClassificationTask(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Performs binary classification.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +nb_inputs = 1
    +
    +
    +
    +default_target_labels = ['target']
    +
    +
    +
    +default_prediction_labels = ['target_pred']
    +
    +
    +
    +
    +class graphnet.models.task.classification.BinaryClassificationTaskLogits(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Performs binary classification form logits.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +nb_inputs = 1
    +
    +
    +
    +default_target_labels = ['target']
    +
    +
    +
    +default_prediction_labels = ['target_pred']
    +
    +
    @@ -519,7 +775,7 @@

    classification Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.task.html b/api/graphnet.models.task.html index 1df19c8e8..0ed5f54b5 100644 --- a/api/graphnet.models.task.html +++ b/api/graphnet.models.task.html @@ -467,14 +467,38 @@
    -
    -

    task

    +
    +

    task

    +

    Physics task-specific modules to be used as model “read-outs”.

    Submodules

    @@ -526,7 +550,7 @@

    taskSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.task.reconstruction.html b/api/graphnet.models.task.reconstruction.html index 1d48b9eb0..06b772e83 100644 --- a/api/graphnet.models.task.reconstruction.html +++ b/api/graphnet.models.task.reconstruction.html @@ -371,11 +371,472 @@ + +

  • @@ -458,7 +919,132 @@ @@ -468,8 +1054,706 @@
    -
    -

    reconstruction

    +
    +

    reconstruction

    +

    Reconstruction-specific Model class(es).

    +
    +
    +class graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs azimuthal angle and associated kappa (1/var).

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['azimuth']
    +
    +
    +
    +default_prediction_labels = ['azimuth_pred', 'azimuth_kappa']
    +
    +
    +
    +nb_inputs = 2
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.AzimuthReconstruction(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: AzimuthReconstructionWithKappa

    +

    Reconstructs azimuthal angle.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['azimuth']
    +
    +
    +
    +default_prediction_labels = ['azimuth_pred']
    +
    +
    +
    +nb_inputs = 2
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.DirectionReconstructionWithKappa(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs direction with kappa from the 3D-vMF distribution.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['direction']
    +
    +
    +
    +default_prediction_labels = ['dir_x_pred', 'dir_y_pred', 'dir_z_pred', 'direction_kappa']
    +
    +
    +
    +nb_inputs = 3
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.ZenithReconstruction(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs zenith angle.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['zenith']
    +
    +
    +
    +default_prediction_labels = ['zenith_pred']
    +
    +
    +
    +nb_inputs = 1
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.ZenithReconstructionWithKappa(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: ZenithReconstruction

    +

    Reconstructs zenith angle and associated kappa (1/var).

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['zenith']
    +
    +
    +
    +default_prediction_labels = ['zenith_pred', 'zenith_kappa']
    +
    +
    +
    +nb_inputs = 2
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.EnergyReconstruction(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs energy using stable method.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['energy']
    +
    +
    +
    +default_prediction_labels = ['energy_pred']
    +
    +
    +
    +nb_inputs = 1
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.EnergyReconstructionWithPower(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs energy.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['energy']
    +
    +
    +
    +default_prediction_labels = ['energy_pred']
    +
    +
    +
    +nb_inputs = 1
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: EnergyReconstruction

    +

    Reconstructs energy and associated uncertainty (log(var)).

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['energy']
    +
    +
    +
    +default_prediction_labels = ['energy_pred', 'energy_sigma']
    +
    +
    +
    +nb_inputs = 2
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.VertexReconstruction(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs vertex position and time.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['vertex']
    +
    +
    +
    +default_prediction_labels = ['position_x_pred', 'position_y_pred', 'position_z_pred', 'interaction_time_pred']
    +
    +
    +
    +nb_inputs = 4
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.PositionReconstruction(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs vertex position.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['position']
    +
    +
    +
    +default_prediction_labels = ['position_x_pred', 'position_y_pred', 'position_z_pred']
    +
    +
    +
    +nb_inputs = 3
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.TimeReconstruction(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs time.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['interaction_time']
    +
    +
    +
    +default_prediction_labels = ['interaction_time_pred']
    +
    +
    +
    +nb_inputs = 1
    +
    +
    +
    +
    +class graphnet.models.task.reconstruction.InelasticityReconstruction(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Task

    +

    Reconstructs interaction inelasticity.

    +

    That is, 1-(track energy / hadronic energy).

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +default_target_labels = ['inelasticity']
    +
    +
    +
    +default_prediction_labels = ['inelasticity_pred']
    +
    +
    +
    +nb_inputs = 1
    +
    +
    @@ -519,7 +1803,7 @@

    reconstruction Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.task.task.html b/api/graphnet.models.task.task.html index 1d38d3bae..35d5c451e 100644 --- a/api/graphnet.models.task.task.html +++ b/api/graphnet.models.task.task.html @@ -378,11 +378,128 @@ + +

  • @@ -458,7 +575,40 @@ @@ -468,8 +618,154 @@
    -
    -

    task

    +
    +

    task

    +

    Base physics task-specific Model class(es).

    +
    +
    +class graphnet.models.task.task.Task(*, hidden_size, loss_function, target_labels, prediction_labels, transform_prediction_and_target, transform_target, transform_inference, transform_support, loss_weight)[source]
    +

    Bases: Model

    +

    Base class for all reconstruction and classification tasks.

    +

    Construct Task.

    +
    +
    Parameters:
    +
      +
    • hidden_size (int) – The number of nodes in the layer feeding into this +tasks, used to construct the affine transformation to the +predicted quantity.

    • +
    • loss_function (LossFunction) – Loss function appropriate to the task.

    • +
    • target_labels (Union[str, List[str], None], default: None) – Name(s) of the quantity/-ies being predicted, used +to extract the target tensor(s) from the Data object in +.compute_loss(…).

    • +
    • prediction_labels (Union[str, List[str], None], default: None) – The name(s) of each column that is predicted by +the model during inference. If not given, the name will auto +matically be set to target_label + _pred.

    • +
    • transform_prediction_and_target (Optional[Callable], default: None) – Optional function to transform +both the predicted and target tensor before passing them to the +loss function. Useful e.g. for having the model predict +quantities on a physical scale, but transforming this scale to +O(1) for a numerically stable loss computation.

    • +
    • transform_target (Optional[Callable], default: None) – Optional function to transform only the target +tensor before passing it, and the predicted tensor, to the loss +function. Useful e.g. for having the model predict a +transformed version of the target quantity, e.g. the log10- +scaled energy, rather than the physical quantity itself. Used +in conjunction with transform_inference to perform the +inverse transform on the predicted quantity to recover the +physical scale.

    • +
    • transform_inference (Optional[Callable], default: None) – Optional function to inverse-transform the +model prediction to recover a physical scale. Used in +conjunction with transform_target.

    • +
    • transform_support (Optional[Tuple], default: None) – Optional tuple to specify minimum and maximum +of the range of validity for the inverse transforms +transform_target and transform_inference in case this is +restricted. By default the invertibility of transform_target +is tested on the range [-1e6, 1e6].

    • +
    • loss_weight (Optional[str], default: None) – Name of the attribute in data containing per-event +loss weights.

    • +
    +
    +
    +
    +
    +abstract property nb_inputs: int
    +

    Return number of inputs assumed by task.

    +
    +
    +
    +abstract property default_target_labels: List[str]
    +

    Return default target labels.

    +
    +
    +
    +abstract property default_prediction_labels: List[str]
    +

    Return default prediction labels.

    +
    +
    +
    +forward(x)[source]
    +

    Forward pass.

    +
    +
    Return type:
    +

    Union[Tensor, Data]

    +
    +
    Parameters:
    +

    x (Tensor | Data) –

    +
    +
    +
    +
    +
    +compute_loss(pred, data)[source]
    +

    Compute loss of pred wrt.

    +

    target labels in data.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • pred (Tensor | Data) –

    • +
    • data (Data) –

    • +
    +
    +
    +
    +
    +
    +inference()[source]
    +

    Activate inference mode.

    +
    +
    Return type:
    +

    None

    +
    +
    +
    +
    +
    +train_eval()[source]
    +

    Deactivate inference mode.

    +
    +
    Return type:
    +

    None

    +
    +
    +
    +
    +
    +
    +class graphnet.models.task.task.IdentityTask(nb_outputs, target_labels, *args, **kwargs)[source]
    +

    Bases: Task

    +

    Identity, or trivial, task.

    +

    Construct IdentityTask.

    +

    Return the nb_outputs as a direct, affine transformation of the last +hidden layer.

    +
    +
    Parameters:
    +
      +
    • nb_outputs (int) –

    • +
    • target_labels (List[str] | Any) –

    • +
    • args (Any) –

    • +
    • kwargs (Any) –

    • +
    +
    +
    +
    +
    +property default_target_labels: List[str]
    +

    Return default target labels.

    +
    +
    +
    +property default_prediction_labels: List[str]
    +

    Return default prediction labels.

    +
    +
    +
    +property nb_inputs: int
    +

    Return number of inputs assumed by task.

    +
    +
    @@ -519,7 +815,7 @@

    task Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.models.utils.html b/api/graphnet.models.utils.html index 6fa20aae4..475f25794 100644 --- a/api/graphnet.models.utils.html +++ b/api/graphnet.models.utils.html @@ -386,11 +386,43 @@ + + @@ -436,7 +468,18 @@ @@ -446,8 +489,69 @@
    -
    -

    utils

    +
    +

    utils

    +

    Utility functions for graphnet.models.

    +
    +
    +graphnet.models.utils.calculate_xyzt_homophily(x, edge_index, batch)[source]
    +

    Calculate xyzt-homophily from a batch of graphs.

    +

    Homophily is a graph scalar quantity that measures the likeness of +variables in nodes. Notice that this calculator assumes a special order of +input features in x.

    +
    +
    Return type:
    +

    Tuple[Tensor, Tensor, Tensor, Tensor]

    +
    +
    Returns:
    +

    Tuple, each element with shape [batch_size,1].

    +
    +
    Parameters:
    +
      +
    • x (Tensor) –

    • +
    • edge_index (LongTensor) –

    • +
    • batch (Batch) –

    • +
    +
    +
    +
    +
    +
    +graphnet.models.utils.calculate_distance_matrix(xyz_coords)[source]
    +

    Calculate the matrix of pairwise distances between pulses.

    +
    +
    Parameters:
    +

    xyz_coords (Tensor) – (x,y,z)-coordinates of pulses, of shape [nb_doms, 3].

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Returns:
    +

    Matrix of pairwise distances, of shape [nb_doms, nb_doms]

    +
    +
    +
    +
    +
    +graphnet.models.utils.knn_graph_batch(batch, k, columns)[source]
    +

    Calculate k-nearest-neighbours with individual k for each batch event.

    +
    +
    Parameters:
    +
      +
    • batch (Batch) – Batch of events.

    • +
    • k (List[int]) – A list of k’s.

    • +
    • columns (List[int]) – The columns of Data.x used for computing the distances. E.g., +Data.x[:,[0,1,2]]

    • +
    +
    +
    Return type:
    +

    Batch

    +
    +
    Returns:
    +

    Returns the same batch of events, but with updated edges.

    +
    +
    +
    @@ -497,7 +601,7 @@

    utilsSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.pisa.fitting.html b/api/graphnet.pisa.fitting.html index 71412f39f..8bc401462 100644 --- a/api/graphnet.pisa.fitting.html +++ b/api/graphnet.pisa.fitting.html @@ -659,7 +659,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.pisa.html b/api/graphnet.pisa.html index 610103161..77788eee8 100644 --- a/api/graphnet.pisa.html +++ b/api/graphnet.pisa.html @@ -465,7 +465,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.pisa.plotting.html b/api/graphnet.pisa.plotting.html index 07ab22cb9..fd463920f 100644 --- a/api/graphnet.pisa.plotting.html +++ b/api/graphnet.pisa.plotting.html @@ -573,7 +573,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.training.callbacks.html b/api/graphnet.training.callbacks.html index 62f336aab..2d3ffa1f3 100644 --- a/api/graphnet.training.callbacks.html +++ b/api/graphnet.training.callbacks.html @@ -344,11 +344,110 @@ + +
  • @@ -408,7 +507,36 @@ @@ -418,8 +546,151 @@
    -
    -

    callbacks

    +
    +

    callbacks

    +

    Callback class(es) for using during model training.

    +
    +
    +class graphnet.training.callbacks.PiecewiseLinearLR(optimizer, milestones, factors, last_epoch, verbose)[source]
    +

    Bases: _LRScheduler

    +

    Interpolate learning rate linearly between milestones.

    +

    Construct PiecewiseLinearLR.

    +

    For each milestone, denoting a specified number of steps, a factor +multiplying the base learning rate is specified. For steps between two +milestones, the learning rate is interpolated linearly between the two +closest milestones. For steps before the first milestone, the factor +for the first milestone is used; vice versa for steps after the last +milestone.

    +
    +
    Parameters:
    +
      +
    • optimizer (Optimizer) – Wrapped optimizer.

    • +
    • milestones (List[int]) – List of step indices. Must be increasing.

    • +
    • factors (List[float]) – List of multiplicative factors. Must be same length as +milestones.

    • +
    • last_epoch (int, default: -1) – The index of the last epoch.

    • +
    • verbose (bool, default: False) – If True, prints a message to stdout for each update.

    • +
    +
    +
    +
    +
    +get_lr()[source]
    +

    Get effective learning rate(s) for each optimizer.

    +
    +
    Return type:
    +

    List[float]

    +
    +
    +
    +
    +
    +
    +class graphnet.training.callbacks.ProgressBar(refresh_rate, process_position)[source]
    +

    Bases: TQDMProgressBar

    +

    Custom progress bar for graphnet.

    +

    Customises the default progress in pytorch-lightning.

    +
    +
    Parameters:
    +
      +
    • refresh_rate (int) –

    • +
    • process_position (int) –

    • +
    +
    +
    +
    +
    +init_validation_tqdm()[source]
    +

    Override for customisation.

    +
    +
    Return type:
    +

    Bar

    +
    +
    +
    +
    +
    +init_predict_tqdm()[source]
    +

    Override for customisation.

    +
    +
    Return type:
    +

    Bar

    +
    +
    +
    +
    +
    +init_test_tqdm()[source]
    +

    Override for customisation.

    +
    +
    Return type:
    +

    Bar

    +
    +
    +
    +
    +
    +init_train_tqdm()[source]
    +

    Override for customisation.

    +
    +
    Return type:
    +

    Bar

    +
    +
    +
    +
    +
    +get_metrics(trainer, model)[source]
    +

    Override to not show the version number in the logging.

    +
    +
    Return type:
    +

    Dict

    +
    +
    Parameters:
    +
      +
    • trainer (Trainer) –

    • +
    • model (LightningModule) –

    • +
    +
    +
    +
    +
    +
    +on_train_epoch_start(trainer, model)[source]
    +

    Print the results of the previous epoch on a separate line.

    +

    This allows the user to see the losses/metrics for previous epochs +while the current is training. The default behaviour in pytorch- +lightning is to overwrite the progress bar from previous epochs.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +
      +
    • trainer (Trainer) –

    • +
    • model (LightningModule) –

    • +
    +
    +
    +
    +
    +
    +on_train_epoch_end(trainer, model)[source]
    +

    Log the final progress bar for the epoch to file.

    +

    Don’t duplciate to stdout.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +
      +
    • trainer (Trainer) –

    • +
    • model (LightningModule) –

    • +
    +
    +
    +
    +
    @@ -469,7 +740,7 @@

    callbacksSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.training.html b/api/graphnet.training.html index 3a7e0dc99..c9ac2b308 100644 --- a/api/graphnet.training.html +++ b/api/graphnet.training.html @@ -424,10 +424,38 @@

    Submodules

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.training.labels.html b/api/graphnet.training.labels.html index cb464052f..6514d41a8 100644 --- a/api/graphnet.training.labels.html +++ b/api/graphnet.training.labels.html @@ -351,11 +351,45 @@ + +
  • @@ -408,7 +442,20 @@
    @@ -418,8 +465,48 @@
    -
    -

    labels

    +
    +

    labels

    +

    Class(es) for constructing training labels at runtime.

    +
    +
    +class graphnet.training.labels.Label(key)[source]
    +

    Bases: ABC, Logger

    +

    Base Label class for producing labels from single Data instance.

    +

    Construct Label.

    +
    +
    Parameters:
    +

    key (str) – The name of the field in Data where the label will be +stored. That is, graph[key] = label.

    +
    +
    +
    +
    +property key: str
    +

    Return value of key.

    +
    +
    +
    +
    +class graphnet.training.labels.Direction(key, azimuth_key, zenith_key)[source]
    +

    Bases: Label

    +

    Class for producing particle direction/pointing label.

    +

    Construct Direction.

    +
    +
    Parameters:
    +
      +
    • key (str, default: 'direction') – The name of the field in Data where the label will be +stored. That is, graph[key] = label.

    • +
    • azimuth_key (str, default: 'azimuth') – The name of the pre-existing key in graph that will +be used to access the azimiuth angle, used when calculating +the direction.

    • +
    • zenith_key (str, default: 'zenith') – The name of the pre-existing key in graph that will +be used to access the zenith angle, used when calculating the +direction.

    • +
    +
    +
    +
    @@ -469,7 +556,7 @@

    labels Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.training.loss_functions.html b/api/graphnet.training.loss_functions.html index 61edd4160..feb15b1bf 100644 --- a/api/graphnet.training.loss_functions.html +++ b/api/graphnet.training.loss_functions.html @@ -358,11 +358,175 @@ + +

  • @@ -408,7 +572,52 @@ @@ -418,8 +627,281 @@
    -
    -

    loss_functions

    +
    +

    loss_functions

    +

    Collection of loss functions.

    +

    All loss functions inherit from LossFunction which ensures a common syntax, +handles per-event weights, etc.

    +
    +
    +class graphnet.training.loss_functions.LossFunction(**kwargs)[source]
    +

    Bases: Model

    +

    Base class for loss functions in graphnet.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +forward(prediction, target, weights, return_elements)[source]
    +

    Forward pass for all loss functions.

    +
    +
    Parameters:
    +
      +
    • prediction (Tensor) – Tensor containing predictions. Shape [N,P]

    • +
    • target (Tensor) – Tensor containing targets. Shape [N,T]

    • +
    • return_elements (bool, default: False) – Whether elementwise loss terms should be returned. +The alternative is to return the averaged loss across examples.

    • +
    • weights (Tensor | None) –

    • +
    +
    +
    Return type:
    +

    Tensor

    +
    +
    Returns:
    +

    Loss, either averaged to a scalar (if return_elements = False) or +elementwise terms with shape [N,] (if return_elements = True).

    +
    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.MSELoss(**kwargs)[source]
    +

    Bases: LossFunction

    +

    Mean squared error loss.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.RMSELoss(**kwargs)[source]
    +

    Bases: MSELoss

    +

    Root mean squared error loss.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.LogCoshLoss(**kwargs)[source]
    +

    Bases: LossFunction

    +

    Log-cosh loss function.

    +

    Acts like x^2 for small x; and like |x| for large x.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.CrossEntropyLoss(options, *args, **kwargs)[source]
    +

    Bases: LossFunction

    +

    Compute cross-entropy loss for classification tasks.

    +

    Predictions are an [N, num_class]-matrix of logits (i.e., non-softmax’ed +probabilities), and targets are an [N,1]-matrix with integer values in +(0, num_classes - 1).

    +

    Construct CrossEntropyLoss.

    +
    +
    Parameters:
    +
      +
    • options (int | List[Any] | Dict[Any, int]) –

    • +
    • args (Any) –

    • +
    • kwargs (Any) –

    • +
    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.BinaryCrossEntropyLoss(**kwargs)[source]
    +

    Bases: LossFunction

    +

    Compute binary cross entropy loss.

    +

    Predictions are vector probabilities (i.e., values between 0 and 1), and +targets should be 0 and 1.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.LogCMK(*args, **kwargs)[source]
    +

    Bases: Function

    +

    MIT License.

    +

    Copyright (c) 2019 Max Ryabinin

    +

    Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the “Software”), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions:

    +

    The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software.

    +

    THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +_____________________

    +

    From [https://github.com/mryab/vmf_loss/blob/master/losses.py] +Modified to use modified Bessel function instead of exponentially scaled ditto +(i.e. .ive -> .iv) as indiciated in [1812.04616] in spite of suggestion in +Sec. 8.2 of this paper. The change has been validated through comparison with +exact calculations for m=2 and m=3 and found to yield the correct results.

    +
    +
    +static forward(ctx, m, kappa)[source]
    +

    Forward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • ctx (Any) –

    • +
    • m (int) –

    • +
    • kappa (Tensor) –

    • +
    +
    +
    +
    +
    +
    +static backward(ctx, grad_output)[source]
    +

    Backward pass.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • ctx (Any) –

    • +
    • grad_output (Tensor) –

    • +
    +
    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.VonMisesFisherLoss(**kwargs)[source]
    +

    Bases: LossFunction

    +

    General class for calculating von Mises-Fisher loss.

    +

    Requires implementation for specific dimension m in which the target and +prediction vectors need to be prepared.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +classmethod log_cmk_exact(m, kappa)[source]
    +

    Calculate $log C_{m}(k)$ term in von Mises-Fisher loss exactly.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • m (int) –

    • +
    • kappa (Tensor) –

    • +
    +
    +
    +
    +
    +
    +classmethod log_cmk_approx(m, kappa)[source]
    +

    Calculate $log C_{m}(k)$ term in von Mises-Fisher loss approx.

    +

    [https://arxiv.org/abs/1812.04616] Sec. 8.2 with additional minus sign.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • m (int) –

    • +
    • kappa (Tensor) –

    • +
    +
    +
    +
    +
    +
    +classmethod log_cmk(m, kappa, kappa_switch)[source]
    +

    Calculate $log C_{m}(k)$ term in von Mises-Fisher loss.

    +

    Since log_cmk_exact is diverges for kappa >~ 700 (using float64 +precision), and since log_cmk_approx is unaccurate for small kappa, +this method automatically switches between the two at kappa_switch, +ensuring continuity at this point.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +
      +
    • m (int) –

    • +
    • kappa (Tensor) –

    • +
    • kappa_switch (float) –

    • +
    +
    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.VonMisesFisher2DLoss(**kwargs)[source]
    +

    Bases: VonMisesFisherLoss

    +

    von Mises-Fisher loss function vectors in the 2D plane.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.EuclideanDistanceLoss(**kwargs)[source]
    +

    Bases: LossFunction

    +

    Mean squared error in three dimensions.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    +
    +
    +class graphnet.training.loss_functions.VonMisesFisher3DLoss(**kwargs)[source]
    +

    Bases: VonMisesFisherLoss

    +

    von Mises-Fisher loss function vectors in the 3D plane.

    +

    Construct LossFunction, saving model config.

    +
    +
    Parameters:
    +

    kwargs (Any) –

    +
    +
    +
    @@ -469,7 +951,7 @@

    loss_functions Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.training.utils.html b/api/graphnet.training.utils.html index d29636c35..ac1d9919d 100644 --- a/api/graphnet.training.utils.html +++ b/api/graphnet.training.utils.html @@ -365,11 +365,61 @@ + +

  • @@ -408,7 +458,22 @@ @@ -418,8 +483,128 @@
    -
    -

    utils

    +
    +

    utils

    +

    Utility functions for graphnet.training.

    +
    +
    +graphnet.training.utils.collate_fn(graphs)[source]
    +

    Remove graphs with less than two DOM hits.

    +

    Should not occur in “production.

    +
    +
    Return type:
    +

    Batch

    +
    +
    Parameters:
    +

    graphs (List[Data]) –

    +
    +
    +
    +
    +
    +graphnet.training.utils.make_dataloader(db, pulsemaps, graph_definition, features, truth, *, batch_size, shuffle, selection, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_table, loss_weight_column, index_column, labels)[source]
    +

    Construct DataLoader instance.

    +
    +
    Return type:
    +

    DataLoader

    +
    +
    Parameters:
    +
      +
    • db (str) –

    • +
    • pulsemaps (str | List[str]) –

    • +
    • graph_definition (GraphDefinition | None) –

    • +
    • features (List[str]) –

    • +
    • truth (List[str]) –

    • +
    • batch_size (int) –

    • +
    • shuffle (bool) –

    • +
    • selection (List[int] | None) –

    • +
    • num_workers (int) –

    • +
    • persistent_workers (bool) –

    • +
    • node_truth (List[str] | None) –

    • +
    • truth_table (str) –

    • +
    • node_truth_table (str | None) –

    • +
    • string_selection (List[int] | None) –

    • +
    • loss_weight_table (str | None) –

    • +
    • loss_weight_column (str | None) –

    • +
    • index_column (str) –

    • +
    • labels (Dict[str, Callable] | None) –

    • +
    +
    +
    +
    +
    +
    +graphnet.training.utils.make_train_validation_dataloader(db, graph_definition, selection, pulsemaps, features, truth, *, batch_size, database_indices, seed, test_size, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_column, loss_weight_table, index_column, labels)[source]
    +

    Construct train and test DataLoader instances.

    +
    +
    Return type:
    +

    Tuple[DataLoader, DataLoader]

    +
    +
    Parameters:
    +
      +
    • db (str) –

    • +
    • graph_definition (GraphDefinition | None) –

    • +
    • selection (List[int] | None) –

    • +
    • pulsemaps (str | List[str]) –

    • +
    • features (List[str]) –

    • +
    • truth (List[str]) –

    • +
    • batch_size (int) –

    • +
    • database_indices (List[int] | None) –

    • +
    • seed (int) –

    • +
    • test_size (float) –

    • +
    • num_workers (int) –

    • +
    • persistent_workers (bool) –

    • +
    • node_truth (str | None) –

    • +
    • truth_table (str) –

    • +
    • node_truth_table (str | None) –

    • +
    • string_selection (List[int] | None) –

    • +
    • loss_weight_column (str | None) –

    • +
    • loss_weight_table (str | None) –

    • +
    • index_column (str) –

    • +
    • labels (Dict[str, Callable] | None) –

    • +
    +
    +
    +
    +
    +
    +graphnet.training.utils.get_predictions(trainer, model, dataloader, prediction_columns, *, node_level, additional_attributes)[source]
    +

    Get model predictions on dataloader.

    +
    +
    Return type:
    +

    DataFrame

    +
    +
    Parameters:
    +
      +
    • trainer (Trainer) –

    • +
    • model (Model) –

    • +
    • dataloader (DataLoader) –

    • +
    • prediction_columns (List[str]) –

    • +
    • node_level (bool) –

    • +
    • additional_attributes (List[str] | None) –

    • +
    +
    +
    +
    +
    +
    +graphnet.training.utils.save_results(db, tag, results, archive, model)[source]
    +

    Save trained model and prediction results in db.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +
      +
    • db (str) –

    • +
    • tag (str) –

    • +
    • results (DataFrame) –

    • +
    • archive (str) –

    • +
    • model (Model) –

    • +
    +
    +
    +
    @@ -469,7 +654,7 @@

    utils Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.training.weight_fitting.html b/api/graphnet.training.weight_fitting.html index aa7cedd58..78198cb33 100644 --- a/api/graphnet.training.weight_fitting.html +++ b/api/graphnet.training.weight_fitting.html @@ -611,7 +611,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.argparse.html b/api/graphnet.utilities.argparse.html index 60f4ad179..814d2bdf2 100644 --- a/api/graphnet.utilities.argparse.html +++ b/api/graphnet.utilities.argparse.html @@ -639,7 +639,7 @@
  • Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.config.base_config.html b/api/graphnet.utilities.config.base_config.html index dda866be3..48747c41d 100644 --- a/api/graphnet.utilities.config.base_config.html +++ b/api/graphnet.utilities.config.base_config.html @@ -357,11 +357,81 @@ + +
  • @@ -465,7 +535,28 @@
    @@ -475,8 +566,88 @@
    -
    -

    base_config

    +
    +

    base_config

    +

    Base config class(es).

    +
    +
    +class graphnet.utilities.config.base_config.BaseConfig[source]
    +

    Bases: BaseModel

    +

    Base class for Configs.

    +

    Create a new model by parsing and validating input data from keyword arguments.

    +

    Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be +validated to form a valid model.

    +

    __init__ uses __pydantic_self__ instead of the more common self for the first arg to +allow self as a field name.

    +
    +
    +
    +
    +classmethod load(path)[source]
    +

    Load BaseConfig from path.

    +
    +
    Return type:
    +

    BaseConfig

    +
    +
    Parameters:
    +

    path (str) –

    +
    +
    +
    +
    +
    +dump(path)[source]
    +

    Save BaseConfig to path as YAML file, or return as string.

    +
    +
    Return type:
    +

    Optional[str]

    +
    +
    Parameters:
    +

    path (str | None) –

    +
    +
    +
    +
    +
    +as_dict()[source]
    +

    Represent BaseConfig as a dict.

    +

    This builds on BaseModel.dict() but can be overwritten.

    +
    +
    Return type:
    +

    Dict[str, Dict[str, Any]]

    +
    +
    +
    +
    +
    +model_config: ClassVar[ConfigDict] = {}
    +

    Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

    +
    +
    +
    +model_fields: ClassVar[dict[str, FieldInfo]] = {}
    +

    Metadata about the fields defined on the model, +mapping of field names to [FieldInfo][pydantic.fields.FieldInfo].

    +

    This replaces Model.__fields__ from Pydantic V1.

    +
    +
    +
    +
    +graphnet.utilities.config.base_config.get_all_argument_values(fn, *args, **kwargs)[source]
    +

    Return dict of all argument values to fn, including defaults.

    +
    +
    Return type:
    +

    Dict[str, Any]

    +
    +
    Parameters:
    +
      +
    • fn (Callable) –

    • +
    • args (Any) –

    • +
    • kwargs (Any) –

    • +
    +
    +
    +
    @@ -526,7 +697,7 @@

    base_config Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.config.configurable.html b/api/graphnet.utilities.config.configurable.html index cc82e52e3..7654534eb 100644 --- a/api/graphnet.utilities.config.configurable.html +++ b/api/graphnet.utilities.config.configurable.html @@ -364,11 +364,54 @@ + +

  • @@ -465,7 +508,22 @@
    @@ -475,8 +533,49 @@
    -
    -

    configurable

    +
    +

    configurable

    +

    Bases for all configurable classes in graphnet.

    +
    +
    +class graphnet.utilities.config.configurable.Configurable[source]
    +

    Bases: ABC

    +

    Base class for all configurable classes in graphnet.

    +

    Construct Configurable.

    +
    +
    +
    +
    +property config: BaseConfig
    +

    Return configuration to re-create the instance.

    +
    +
    +
    +save_config(path)[source]
    +

    Save Config to path as YAML file.

    +
    +
    Return type:
    +

    None

    +
    +
    Parameters:
    +

    path (str) –

    +
    +
    +
    +
    +
    +abstract classmethod from_config(source)[source]
    +

    Construct instance from source configuration.

    +
    +
    Return type:
    +

    Any

    +
    +
    Parameters:
    +

    source (BaseConfig | str) –

    +
    +
    +
    +
    @@ -526,7 +625,7 @@

    configurable Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.config.dataset_config.html b/api/graphnet.utilities.config.dataset_config.html index 27d11fb12..d3b202a27 100644 --- a/api/graphnet.utilities.config.dataset_config.html +++ b/api/graphnet.utilities.config.dataset_config.html @@ -371,11 +371,198 @@ + +

  • @@ -465,7 +652,54 @@ @@ -475,8 +709,206 @@
    -
    -

    dataset_config

    +
    +

    dataset_config

    +

    Config classes for the graphnet.data.dataset module.

    +
    +
    +class graphnet.utilities.config.dataset_config.DatasetConfig(*, path, pulsemaps, features, truth, node_truth, index_column, truth_table, node_truth_table, string_selection, selection, loss_weight_table, loss_weight_column, loss_weight_default_value, seed, graph_definition)[source]
    +

    Bases: BaseConfig

    +

    Configuration for all `Dataset`s.

    +

    Construct DataConfig.

    +

    Can be used for dataset configuration as code, thereby making dataset +construction more transparent and reproducible.

    +

    Examples

    +

    In one session, do:

    +
    >>> dataset = Dataset(...)
    +>>> dataset.config.dump()
    +path: (...)
    +pulsemaps:
    +    - (...)
    +(...)
    +>>> dataset.config.dump("dataset.yml")
    +
    +
    +

    In another session, you can then do: +>>> dataset = Dataset.from_config(“dataset.yml”)

    +

    # Uniquely for DatasetConfig, you can also define and load +# multiple datasets +>>> dataset.config.selection = {

    +
    +

    “train”: “event_no % 2 == 0”, +“test”: “event_no % 2 == 1”,

    +
    +

    } +>>> dataset.config.dump(“dataset.yml”) +>>> datasets: Dict[str, Dataset] = Dataset.from_config(

    +
    +

    “dataset.yml”

    +
    +

    ) +>>> datasets +{

    +
    +

    “train”: Dataset(…), +“test”: Dataset(…),

    +
    +

    }

    +

    # You can also combine multiple selections into a single, named +# dataset +>>> dataset.config.selection = {

    +
    +
    +
    “train”: [

    “event_no % 2 == 0 & abs(pid) == 12”, +“event_no % 2 == 0 & abs(pid) == 14”, +“event_no % 2 == 0 & abs(pid) == 16”,

    +
    +
    +

    ], +(…)

    +
    +

    } +>>> dataset.config.dump(“dataset.yml”) +>>> datasets: Dict[str, EnsembleDataset] = Dataset.from_config(

    +
    +

    “dataset.yml”

    +
    +

    ) +>>> datasets +{

    +
    +

    “train”: EnsembleDataset(…), +(…)

    +
    +

    }

    +

    # Finally, you can still reference existing selection files in CSV +# or JSON formats: +>>> dataset.config.selection = {

    +
    +

    “train”: “50000 random events ~ train_selection.csv”, +“test”: “test_selection.csv”,

    +
    +

    }

    +
    +
    Parameters:
    +
      +
    • path (str | List[str]) –

    • +
    • pulsemaps (str | List[str]) –

    • +
    • features (List[str]) –

    • +
    • truth (List[str]) –

    • +
    • node_truth (List[str] | None) –

    • +
    • index_column (str) –

    • +
    • truth_table (str) –

    • +
    • node_truth_table (str | None) –

    • +
    • string_selection (List[int] | None) –

    • +
    • selection (str | List[str] | List[int | List[int]] | Dict[str, str | List[str]] | None) –

    • +
    • loss_weight_table (str | None) –

    • +
    • loss_weight_column (str | None) –

    • +
    • loss_weight_default_value (float | None) –

    • +
    • seed (int | None) –

    • +
    • graph_definition (Any) –

    • +
    +
    +
    +
    +
    +path: Union[str, List[str]]
    +
    +
    +
    +pulsemaps: Union[str, List[str]]
    +
    +
    +
    +features: List[str]
    +
    +
    +
    +truth: List[str]
    +
    +
    +
    +node_truth: Optional[List[str]]
    +
    +
    +
    +index_column: str
    +
    +
    +
    +truth_table: str
    +
    +
    +
    +node_truth_table: Optional[str]
    +
    +
    +
    +string_selection: Optional[List[int]]
    +
    +
    +
    +selection: Union[str, List[str], List[Union[int, List[int]]], Dict[str, Union[str, List[str]]], None]
    +
    +
    +
    +loss_weight_table: Optional[str]
    +
    +
    +
    +loss_weight_column: Optional[str]
    +
    +
    +
    +loss_weight_default_value: Optional[float]
    +
    +
    +
    +seed: Optional[int]
    +
    +
    +
    +graph_definition: Any
    +
    +
    +
    +as_dict()[source]
    +

    Represent ModelConfig as a dict.

    +

    This builds on BaseModel.dict() but wraps the output in a single-key +dictionary to make it unambiguous to identify model arguments that are +themselves models.

    +
    +
    Return type:
    +

    Dict[str, Dict[str, Any]]

    +
    +
    +
    +
    +
    +model_config: ClassVar[ConfigDict] = {}
    +

    Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

    +
    +
    +
    +model_fields: ClassVar[dict[str, FieldInfo]] = {'features': FieldInfo(annotation=List[str], required=True), 'graph_definition': FieldInfo(annotation=Any, required=False), 'index_column': FieldInfo(annotation=str, required=False, default='event_no'), 'loss_weight_column': FieldInfo(annotation=Union[str, NoneType], required=False), 'loss_weight_default_value': FieldInfo(annotation=Union[float, NoneType], required=False), 'loss_weight_table': FieldInfo(annotation=Union[str, NoneType], required=False), 'node_truth': FieldInfo(annotation=Union[List[str], NoneType], required=False), 'node_truth_table': FieldInfo(annotation=Union[str, NoneType], required=False), 'path': FieldInfo(annotation=Union[str, List[str]], required=True), 'pulsemaps': FieldInfo(annotation=Union[str, List[str]], required=True), 'seed': FieldInfo(annotation=Union[int, NoneType], required=False), 'selection': FieldInfo(annotation=Union[str, List[str], List[Union[int, List[int]]], Dict[str, Union[str, List[str]]], NoneType], required=False), 'string_selection': FieldInfo(annotation=Union[List[int], NoneType], required=False), 'truth': FieldInfo(annotation=List[str], required=True), 'truth_table': FieldInfo(annotation=str, required=False, default='truth')}
    +

    Metadata about the fields defined on the model, +mapping of field names to [FieldInfo][pydantic.fields.FieldInfo].

    +

    This replaces Model.__fields__ from Pydantic V1.

    +
    +
    +
    +
    +graphnet.utilities.config.dataset_config.save_dataset_config(init_fn)[source]
    +

    Save the arguments to __init__ functions as member DatasetConfig.

    +
    +
    Return type:
    +

    Callable

    +
    +
    Parameters:
    +

    init_fn (Callable) –

    +
    +
    +
    @@ -526,7 +958,7 @@

    dataset_config<

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.config.html b/api/graphnet.utilities.config.html index bba4a6957..59a759d23 100644 --- a/api/graphnet.utilities.config.html +++ b/api/graphnet.utilities.config.html @@ -474,17 +474,44 @@
    -
    -

    config

    +
    +

    config

    +

    Modules for configuration files for use across graphnet.

    Submodules

    @@ -536,7 +563,7 @@

    config Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.config.model_config.html b/api/graphnet.utilities.config.model_config.html index 5c6fa0fbb..52531fc50 100644 --- a/api/graphnet.utilities.config.model_config.html +++ b/api/graphnet.utilities.config.model_config.html @@ -378,11 +378,81 @@ + +

  • @@ -465,7 +535,28 @@ @@ -475,8 +566,88 @@
    -
    -

    model_config

    +
    +

    model_config

    +

    Config classes for the graphnet.models module.

    +
    +
    +class graphnet.utilities.config.model_config.ModelConfig(*, class_name, arguments)[source]
    +

    Bases: BaseConfig

    +

    Configuration for all `Model`s.

    +

    Construct ModelConfig.

    +

    Can be used for model configuration as code, thereby making model +construction more transparent and reproducible. Note that this does +not save any trainable weights, meaning this is only a configuration +for the model’s hyperparameters. Any model instantiated from a +ModelConfig or file will be randomly initialised, and thus should be +trained.

    +

    Examples

    +

    In one session, do:

    +
    >>> model = Model(...)
    +>>> model.config.dump()
    +arguments:
    +    - (...): (...)
    +class_name: Model
    +>>> model.config.dump("model.yml")
    +
    +
    +

    In another session, you can then do: +>>> model = Model.from_config(“model.yml”)

    +
    +
    Parameters:
    +
      +
    • class_name (str) –

    • +
    • arguments (Dict[str, Any]) –

    • +
    +
    +
    +
    +
    +class_name: str
    +
    +
    +
    +arguments: Dict[str, Any]
    +
    +
    +
    +as_dict()[source]
    +

    Represent ModelConfig as a dict.

    +

    This builds on BaseModel.dict() but wraps the output in a single-key +dictionary to make it unambiguous to identify model arguments that are +themselves models.

    +
    +
    Return type:
    +

    Dict[str, Dict[str, Any]]

    +
    +
    +
    +
    +
    +model_config: ClassVar[ConfigDict] = {}
    +

    Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

    +
    +
    +
    +model_fields: ClassVar[dict[str, FieldInfo]] = {'arguments': FieldInfo(annotation=Dict[str, Any], required=True), 'class_name': FieldInfo(annotation=str, required=True)}
    +

    Metadata about the fields defined on the model, +mapping of field names to [FieldInfo][pydantic.fields.FieldInfo].

    +

    This replaces Model.__fields__ from Pydantic V1.

    +
    +
    +
    +
    +graphnet.utilities.config.model_config.save_model_config(init_fn)[source]
    +

    Save the arguments to __init__ functions as a member ModelConfig.

    +
    +
    Return type:
    +

    Callable

    +
    +
    Parameters:
    +

    init_fn (Callable) –

    +
    +
    +
    @@ -526,7 +697,7 @@

    model_config Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.config.parsing.html b/api/graphnet.utilities.config.parsing.html index 210d5f809..0e932473f 100644 --- a/api/graphnet.utilities.config.parsing.html +++ b/api/graphnet.utilities.config.parsing.html @@ -385,11 +385,70 @@ + +

  • @@ -465,7 +524,24 @@ @@ -475,8 +551,91 @@
    -
    -

    parsing

    +
    +

    parsing

    +

    Utility functions for parsing for using with Config-classes.

    +
    +
    +graphnet.utilities.config.parsing.traverse_and_apply(obj, fn, fn_kwargs)[source]
    +

    Apply fn to all elements in obj, resulting in same structure.

    +
    +
    Return type:
    +

    Any

    +
    +
    Parameters:
    +
      +
    • obj (Any) –

    • +
    • fn (Callable) –

    • +
    • fn_kwargs (Dict[str, Any] | None) –

    • +
    +
    +
    +
    +
    +
    +graphnet.utilities.config.parsing.list_all_submodules(*packages)[source]
    +

    List all submodules in packages recursively.

    +
    +
    Return type:
    +

    List[ModuleType]

    +
    +
    Parameters:
    +

    packages (module) –

    +
    +
    +
    +
    +
    +graphnet.utilities.config.parsing.get_all_grapnet_classes(*packages)[source]
    +

    List all grapnet classes in packages.

    +
    +
    Return type:
    +

    Dict[str, type]

    +
    +
    Parameters:
    +

    packages (module) –

    +
    +
    +
    +
    +
    +graphnet.utilities.config.parsing.is_graphnet_module(obj)[source]
    +

    Return whether obj is a module in graphnet.

    +
    +
    Return type:
    +

    bool

    +
    +
    Parameters:
    +

    obj (module) –

    +
    +
    +
    +
    +
    +graphnet.utilities.config.parsing.is_graphnet_class(obj)[source]
    +

    Return whether obj is a class in graphnet.

    +
    +
    Return type:
    +

    bool

    +
    +
    Parameters:
    +

    obj (type) –

    +
    +
    +
    +
    +
    +graphnet.utilities.config.parsing.get_graphnet_classes(module)[source]
    +

    Return a lookup of all graphnet class names in module.

    +
    +
    Return type:
    +

    Dict[str, type]

    +
    +
    Parameters:
    +

    module (module) –

    +
    +
    +
    @@ -526,7 +685,7 @@

    parsingSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.config.training_config.html b/api/graphnet.utilities.config.training_config.html index bdb49dee2..26774e9f2 100644 --- a/api/graphnet.utilities.config.training_config.html +++ b/api/graphnet.utilities.config.training_config.html @@ -392,11 +392,81 @@ + +

  • @@ -465,7 +535,28 @@ @@ -475,8 +566,58 @@
    -
    -

    training_config

    +
    +

    training_config

    +

    Config classes for the graphnet.training module.

    +
    +
    +class graphnet.utilities.config.training_config.TrainingConfig(*, target, early_stopping_patience, fit, dataloader)[source]
    +

    Bases: BaseConfig

    +

    Configuration for all trainings.

    +

    Create a new model by parsing and validating input data from keyword arguments.

    +

    Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be +validated to form a valid model.

    +

    __init__ uses __pydantic_self__ instead of the more common self for the first arg to +allow self as a field name.

    +
    +
    Parameters:
    +
      +
    • target (str | List[str]) –

    • +
    • early_stopping_patience (int) –

    • +
    • fit (Dict[str, Any]) –

    • +
    • dataloader (Dict[str, Any]) –

    • +
    +
    +
    +
    +
    +target: Union[str, List[str]]
    +
    +
    +
    +early_stopping_patience: int
    +
    +
    +
    +fit: Dict[str, Any]
    +
    +
    +
    +dataloader: Dict[str, Any]
    +
    +
    +
    +model_config: ClassVar[ConfigDict] = {}
    +

    Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

    +
    +
    +
    +model_fields: ClassVar[dict[str, FieldInfo]] = {'dataloader': FieldInfo(annotation=Dict[str, Any], required=True), 'early_stopping_patience': FieldInfo(annotation=int, required=True), 'fit': FieldInfo(annotation=Dict[str, Any], required=True), 'target': FieldInfo(annotation=Union[str, List[str]], required=True)}
    +

    Metadata about the fields defined on the model, +mapping of field names to [FieldInfo][pydantic.fields.FieldInfo].

    +

    This replaces Model.__fields__ from Pydantic V1.

    +
    +
    @@ -526,7 +667,7 @@

    training_confi

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.decorators.html b/api/graphnet.utilities.decorators.html index 0203d3c40..92ae0d5b7 100644 --- a/api/graphnet.utilities.decorators.html +++ b/api/graphnet.utilities.decorators.html @@ -484,7 +484,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.filesys.html b/api/graphnet.utilities.filesys.html index 8e0c99feb..f13e120c5 100644 --- a/api/graphnet.utilities.filesys.html +++ b/api/graphnet.utilities.filesys.html @@ -604,7 +604,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.html b/api/graphnet.utilities.html index 847defb5c..5723c23e5 100644 --- a/api/graphnet.utilities.html +++ b/api/graphnet.utilities.html @@ -476,7 +476,10 @@
  • Logger
  • -
  • maths
  • +
  • maths +
  • @@ -528,7 +531,7 @@
    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.imports.html b/api/graphnet.utilities.imports.html index 7166dfe02..7ec9aa916 100644 --- a/api/graphnet.utilities.imports.html +++ b/api/graphnet.utilities.imports.html @@ -581,7 +581,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.logging.html b/api/graphnet.utilities.logging.html index a16ad7888..7972b4bad 100644 --- a/api/graphnet.utilities.logging.html +++ b/api/graphnet.utilities.logging.html @@ -831,7 +831,7 @@

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/graphnet.utilities.maths.html b/api/graphnet.utilities.maths.html index 9fc848fe1..8eb8048fa 100644 --- a/api/graphnet.utilities.maths.html +++ b/api/graphnet.utilities.maths.html @@ -393,11 +393,25 @@ + + @@ -422,7 +436,14 @@
    @@ -432,8 +453,22 @@
    -
    -

    maths

    +
    +

    maths

    +

    Collection of assorted “maths-like” functions.

    +
    +
    +graphnet.utilities.maths.eps_like(tensor)[source]
    +

    Return eps matching tensor’s dtype.

    +
    +
    Return type:
    +

    Tensor

    +
    +
    Parameters:
    +

    tensor (Tensor) –

    +
    +
    +
    @@ -483,7 +518,7 @@

    maths Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/api/modules.html b/api/modules.html index 6f49ef2ab..0bc19ad98 100644 --- a/api/modules.html +++ b/api/modules.html @@ -362,7 +362,7 @@

    srcSphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/contribute.html b/contribute.html index 0086e93f9..1a6ccae8d 100644 --- a/contribute.html +++ b/contribute.html @@ -484,7 +484,7 @@

    Code quality Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/genindex.html b/genindex.html index 360c7a1ef..7ef7dc30a 100644 --- a/genindex.html +++ b/genindex.html @@ -339,23 +339,44 @@

    Index

    | N | O | P + | Q | R | S | T | U + | V | W + | Z

    A

    - +
    @@ -363,10 +384,20 @@

    A

    B

    @@ -376,26 +407,62 @@

    C

    @@ -409,8 +476,14 @@

    D

  • DataConverter (class in graphnet.data.dataconverter)
  • - - +
  • default_prediction_labels (graphnet.models.task.classification.BinaryClassificationTask attribute) + +
  • +
  • default_target_labels (graphnet.models.task.classification.BinaryClassificationTask attribute) + +
  • + +

    E

    @@ -437,7 +624,23 @@

    E

    F

    - +
    @@ -478,12 +735,30 @@

    G

    - + -
    -
    • graphnet.data.extractors.utilities.types @@ -653,6 +982,13 @@

      G

    • +
    • + graphnet.data.pipeline + +
    • @@ -712,95 +1048,399 @@

      G

  • - graphnet.pisa + graphnet.deployment.i3modules.graphnet_module
  • - graphnet.pisa.fitting + graphnet.models
  • - graphnet.pisa.plotting + graphnet.models.coarsening
  • - graphnet.training + graphnet.models.components
  • +
    • - graphnet.training.weight_fitting + graphnet.models.components.layers
    • - graphnet.utilities + graphnet.models.components.pool
    • - graphnet.utilities.argparse + graphnet.models.detector
    • - graphnet.utilities.decorators + graphnet.models.detector.detector
    • - graphnet.utilities.filesys + graphnet.models.detector.icecube
    • - graphnet.utilities.imports + graphnet.models.detector.prometheus
    • - graphnet.utilities.logging + graphnet.models.gnn
    • -
    +
  • + graphnet.models.gnn.convnet -

    H

    - - - +
    +
  • + graphnet.models.gnn.dynedge_jinst + +
  • +
  • + graphnet.models.gnn.dynedge_kaggle_tito + +
  • +
  • + graphnet.models.gnn.gnn + +
  • +
  • + graphnet.models.graphs + +
  • +
  • + graphnet.models.graphs.edges + +
  • +
  • + graphnet.models.graphs.edges.edges + +
  • +
  • + graphnet.models.graphs.graph_definition + +
  • +
  • + graphnet.models.graphs.graphs + +
  • +
  • + graphnet.models.graphs.nodes + +
  • +
  • + graphnet.models.graphs.nodes.nodes + +
  • +
  • + graphnet.models.model + +
  • +
  • + graphnet.models.standard_model + +
  • +
  • + graphnet.models.task + +
  • +
  • + graphnet.models.task.classification + +
  • +
  • + graphnet.models.task.reconstruction + +
  • +
  • + graphnet.models.task.task + +
  • +
  • + graphnet.models.utils + +
  • +
  • + graphnet.pisa + +
  • +
  • + graphnet.pisa.fitting + +
  • +
  • + graphnet.pisa.plotting + +
  • +
  • + graphnet.training + +
  • +
  • + graphnet.training.callbacks + +
  • +
  • + graphnet.training.labels + +
  • +
  • + graphnet.training.loss_functions + +
  • +
  • + graphnet.training.utils + +
  • +
  • + graphnet.training.weight_fitting + +
  • +
  • + graphnet.utilities + +
  • +
  • + graphnet.utilities.argparse + +
  • +
  • + graphnet.utilities.config + +
  • +
  • + graphnet.utilities.config.base_config + +
  • +
  • + graphnet.utilities.config.configurable + +
  • +
  • + graphnet.utilities.config.dataset_config + +
  • +
  • + graphnet.utilities.config.model_config + +
  • +
  • + graphnet.utilities.config.parsing + +
  • +
  • + graphnet.utilities.config.training_config + +
  • +
  • + graphnet.utilities.decorators + +
  • +
  • + graphnet.utilities.filesys + +
  • +
  • + graphnet.utilities.imports + +
  • +
  • + graphnet.utilities.logging + +
  • +
  • + graphnet.utilities.maths + +
  • +
  • GraphNeTI3Module (class in graphnet.deployment.i3modules.graphnet_module) +
  • +
  • group_by() (in module graphnet.models.components.pool) +
  • +
  • group_pulses_to_dom() (in module graphnet.models.components.pool) +
  • +
  • group_pulses_to_pmt() (in module graphnet.models.components.pool) +
  • +
    + +

    H

    + + + - + +
    +
  • key (graphnet.training.labels.Label property) +
  • +

    L

    +
    @@ -904,6 +1626,10 @@

    L

    M

    +

    N

    @@ -1011,9 +1875,59 @@

    N

    @@ -1021,6 +1935,12 @@

    N

    O

    + @@ -1032,21 +1952,69 @@

    P

  • pairwise_shuffle() (in module graphnet.data.utilities.random)
  • ParquetDataConverter (class in graphnet.data.parquet.parquet_dataconverter) +
  • +
  • ParquetDataset (class in graphnet.data.dataset.parquet.parquet_dataset)
  • ParquetToSQLiteConverter (class in graphnet.data.utilities.parquet_to_sqlite) +
  • +
  • parse_graph_definition() (in module graphnet.data.dataset.dataset) +
  • +
  • path (graphnet.data.dataset.dataset.Dataset property) + +
  • +
  • PiecewiseLinearLR (class in graphnet.training.callbacks)
  • plot_1D_contour() (in module graphnet.pisa.plotting)
  • - - + +
    + +

    Q

    + + @@ -1055,7 +2023,11 @@

    P

    R

    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
        graphnet.data.dataconverter
        + graphnet.data.dataloader +
        + graphnet.data.dataset +
        + graphnet.data.dataset.dataset +
        + graphnet.data.dataset.parquet +
        + graphnet.data.dataset.parquet.parquet_dataset +
        + graphnet.data.dataset.sqlite +
        + graphnet.data.dataset.sqlite.sqlite_dataset +
        + graphnet.data.dataset.sqlite.sqlite_dataset_perturbed +
        @@ -455,6 +495,11 @@

    Python Module Index

        graphnet.data.parquet.parquet_dataconverter
        + graphnet.data.pipeline +
        @@ -495,6 +540,156 @@

    Python Module Index

        graphnet.deployment
        + graphnet.deployment.i3modules.graphnet_module +
        + graphnet.models +
        + graphnet.models.coarsening +
        + graphnet.models.components +
        + graphnet.models.components.layers +
        + graphnet.models.components.pool +
        + graphnet.models.detector +
        + graphnet.models.detector.detector +
        + graphnet.models.detector.icecube +
        + graphnet.models.detector.prometheus +
        + graphnet.models.gnn +
        + graphnet.models.gnn.convnet +
        + graphnet.models.gnn.dynedge +
        + graphnet.models.gnn.dynedge_jinst +
        + graphnet.models.gnn.dynedge_kaggle_tito +
        + graphnet.models.gnn.gnn +
        + graphnet.models.graphs +
        + graphnet.models.graphs.edges +
        + graphnet.models.graphs.edges.edges +
        + graphnet.models.graphs.graph_definition +
        + graphnet.models.graphs.graphs +
        + graphnet.models.graphs.nodes +
        + graphnet.models.graphs.nodes.nodes +
        + graphnet.models.model +
        + graphnet.models.standard_model +
        + graphnet.models.task +
        + graphnet.models.task.classification +
        + graphnet.models.task.reconstruction +
        + graphnet.models.task.task +
        + graphnet.models.utils +
        @@ -515,6 +710,26 @@

    Python Module Index

        graphnet.training
        + graphnet.training.callbacks +
        + graphnet.training.labels +
        + graphnet.training.loss_functions +
        + graphnet.training.utils +
        @@ -530,6 +745,41 @@

    Python Module Index

        graphnet.utilities.argparse
        + graphnet.utilities.config +
        + graphnet.utilities.config.base_config +
        + graphnet.utilities.config.configurable +
        + graphnet.utilities.config.dataset_config +
        + graphnet.utilities.config.model_config +
        + graphnet.utilities.config.parsing +
        + graphnet.utilities.config.training_config +
        @@ -550,6 +800,11 @@

    Python Module Index

        graphnet.utilities.logging
        + graphnet.utilities.maths +
    @@ -575,7 +830,7 @@

    Python Module Index

  • Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/search.html b/search.html index 32901805c..4199f1026 100644 --- a/search.html +++ b/search.html @@ -361,7 +361,7 @@

    Search

    Created using - Sphinx 7.2.5. + Sphinx 7.2.6. and Material for Sphinx diff --git a/searchindex.js b/searchindex.js index 4df38d653..200c93687 100644 --- a/searchindex.js +++ b/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["about", "api/graphnet", "api/graphnet.constants", "api/graphnet.data", "api/graphnet.data.constants", "api/graphnet.data.dataconverter", "api/graphnet.data.dataloader", "api/graphnet.data.dataset", "api/graphnet.data.dataset.dataset", "api/graphnet.data.dataset.parquet", "api/graphnet.data.dataset.parquet.parquet_dataset", "api/graphnet.data.dataset.sqlite", "api/graphnet.data.dataset.sqlite.sqlite_dataset", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed", "api/graphnet.data.extractors", "api/graphnet.data.extractors.i3extractor", "api/graphnet.data.extractors.i3featureextractor", "api/graphnet.data.extractors.i3genericextractor", "api/graphnet.data.extractors.i3hybridrecoextractor", "api/graphnet.data.extractors.i3ntmuonlabelsextractor", "api/graphnet.data.extractors.i3particleextractor", "api/graphnet.data.extractors.i3pisaextractor", "api/graphnet.data.extractors.i3quesoextractor", "api/graphnet.data.extractors.i3retroextractor", "api/graphnet.data.extractors.i3splinempeextractor", "api/graphnet.data.extractors.i3truthextractor", "api/graphnet.data.extractors.i3tumextractor", "api/graphnet.data.extractors.utilities", "api/graphnet.data.extractors.utilities.collections", "api/graphnet.data.extractors.utilities.frames", "api/graphnet.data.extractors.utilities.types", "api/graphnet.data.parquet", "api/graphnet.data.parquet.parquet_dataconverter", "api/graphnet.data.pipeline", "api/graphnet.data.sqlite", "api/graphnet.data.sqlite.sqlite_dataconverter", "api/graphnet.data.sqlite.sqlite_utilities", "api/graphnet.data.utilities", "api/graphnet.data.utilities.parquet_to_sqlite", "api/graphnet.data.utilities.random", "api/graphnet.data.utilities.string_selection_resolver", "api/graphnet.deployment", "api/graphnet.deployment.i3modules", "api/graphnet.deployment.i3modules.deployer", "api/graphnet.deployment.i3modules.graphnet_module", "api/graphnet.models", "api/graphnet.models.coarsening", "api/graphnet.models.components", "api/graphnet.models.components.layers", "api/graphnet.models.components.pool", "api/graphnet.models.detector", "api/graphnet.models.detector.detector", "api/graphnet.models.detector.icecube", "api/graphnet.models.detector.prometheus", "api/graphnet.models.gnn", "api/graphnet.models.gnn.convnet", "api/graphnet.models.gnn.dynedge", "api/graphnet.models.gnn.dynedge_jinst", "api/graphnet.models.gnn.dynedge_kaggle_tito", "api/graphnet.models.gnn.gnn", "api/graphnet.models.graphs", "api/graphnet.models.graphs.edges", "api/graphnet.models.graphs.edges.edges", "api/graphnet.models.graphs.graph_definition", "api/graphnet.models.graphs.graphs", "api/graphnet.models.graphs.nodes", "api/graphnet.models.graphs.nodes.nodes", "api/graphnet.models.model", "api/graphnet.models.standard_model", "api/graphnet.models.task", "api/graphnet.models.task.classification", "api/graphnet.models.task.reconstruction", "api/graphnet.models.task.task", "api/graphnet.models.utils", "api/graphnet.pisa", "api/graphnet.pisa.fitting", "api/graphnet.pisa.plotting", "api/graphnet.training", "api/graphnet.training.callbacks", "api/graphnet.training.labels", "api/graphnet.training.loss_functions", "api/graphnet.training.utils", "api/graphnet.training.weight_fitting", "api/graphnet.utilities", "api/graphnet.utilities.argparse", "api/graphnet.utilities.config", "api/graphnet.utilities.config.base_config", "api/graphnet.utilities.config.configurable", "api/graphnet.utilities.config.dataset_config", "api/graphnet.utilities.config.model_config", "api/graphnet.utilities.config.parsing", "api/graphnet.utilities.config.training_config", "api/graphnet.utilities.decorators", "api/graphnet.utilities.filesys", "api/graphnet.utilities.imports", "api/graphnet.utilities.logging", "api/graphnet.utilities.maths", "api/modules", "contribute", "index", "install"], "filenames": ["about.md", "api/graphnet.rst", "api/graphnet.constants.rst", "api/graphnet.data.rst", "api/graphnet.data.constants.rst", "api/graphnet.data.dataconverter.rst", "api/graphnet.data.dataloader.rst", "api/graphnet.data.dataset.rst", "api/graphnet.data.dataset.dataset.rst", "api/graphnet.data.dataset.parquet.rst", "api/graphnet.data.dataset.parquet.parquet_dataset.rst", "api/graphnet.data.dataset.sqlite.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.rst", "api/graphnet.data.extractors.rst", "api/graphnet.data.extractors.i3extractor.rst", "api/graphnet.data.extractors.i3featureextractor.rst", "api/graphnet.data.extractors.i3genericextractor.rst", "api/graphnet.data.extractors.i3hybridrecoextractor.rst", "api/graphnet.data.extractors.i3ntmuonlabelsextractor.rst", "api/graphnet.data.extractors.i3particleextractor.rst", "api/graphnet.data.extractors.i3pisaextractor.rst", "api/graphnet.data.extractors.i3quesoextractor.rst", "api/graphnet.data.extractors.i3retroextractor.rst", "api/graphnet.data.extractors.i3splinempeextractor.rst", "api/graphnet.data.extractors.i3truthextractor.rst", "api/graphnet.data.extractors.i3tumextractor.rst", "api/graphnet.data.extractors.utilities.rst", "api/graphnet.data.extractors.utilities.collections.rst", "api/graphnet.data.extractors.utilities.frames.rst", "api/graphnet.data.extractors.utilities.types.rst", "api/graphnet.data.parquet.rst", "api/graphnet.data.parquet.parquet_dataconverter.rst", "api/graphnet.data.pipeline.rst", "api/graphnet.data.sqlite.rst", "api/graphnet.data.sqlite.sqlite_dataconverter.rst", "api/graphnet.data.sqlite.sqlite_utilities.rst", "api/graphnet.data.utilities.rst", "api/graphnet.data.utilities.parquet_to_sqlite.rst", "api/graphnet.data.utilities.random.rst", "api/graphnet.data.utilities.string_selection_resolver.rst", "api/graphnet.deployment.rst", "api/graphnet.deployment.i3modules.rst", "api/graphnet.deployment.i3modules.deployer.rst", "api/graphnet.deployment.i3modules.graphnet_module.rst", "api/graphnet.models.rst", "api/graphnet.models.coarsening.rst", "api/graphnet.models.components.rst", "api/graphnet.models.components.layers.rst", "api/graphnet.models.components.pool.rst", "api/graphnet.models.detector.rst", "api/graphnet.models.detector.detector.rst", "api/graphnet.models.detector.icecube.rst", "api/graphnet.models.detector.prometheus.rst", "api/graphnet.models.gnn.rst", "api/graphnet.models.gnn.convnet.rst", "api/graphnet.models.gnn.dynedge.rst", "api/graphnet.models.gnn.dynedge_jinst.rst", "api/graphnet.models.gnn.dynedge_kaggle_tito.rst", "api/graphnet.models.gnn.gnn.rst", "api/graphnet.models.graphs.rst", "api/graphnet.models.graphs.edges.rst", "api/graphnet.models.graphs.edges.edges.rst", "api/graphnet.models.graphs.graph_definition.rst", "api/graphnet.models.graphs.graphs.rst", "api/graphnet.models.graphs.nodes.rst", "api/graphnet.models.graphs.nodes.nodes.rst", "api/graphnet.models.model.rst", "api/graphnet.models.standard_model.rst", "api/graphnet.models.task.rst", "api/graphnet.models.task.classification.rst", "api/graphnet.models.task.reconstruction.rst", "api/graphnet.models.task.task.rst", "api/graphnet.models.utils.rst", "api/graphnet.pisa.rst", "api/graphnet.pisa.fitting.rst", "api/graphnet.pisa.plotting.rst", "api/graphnet.training.rst", "api/graphnet.training.callbacks.rst", "api/graphnet.training.labels.rst", "api/graphnet.training.loss_functions.rst", "api/graphnet.training.utils.rst", "api/graphnet.training.weight_fitting.rst", "api/graphnet.utilities.rst", "api/graphnet.utilities.argparse.rst", "api/graphnet.utilities.config.rst", "api/graphnet.utilities.config.base_config.rst", "api/graphnet.utilities.config.configurable.rst", "api/graphnet.utilities.config.dataset_config.rst", "api/graphnet.utilities.config.model_config.rst", "api/graphnet.utilities.config.parsing.rst", "api/graphnet.utilities.config.training_config.rst", "api/graphnet.utilities.decorators.rst", "api/graphnet.utilities.filesys.rst", "api/graphnet.utilities.imports.rst", "api/graphnet.utilities.logging.rst", "api/graphnet.utilities.maths.rst", "api/modules.rst", "contribute.md", "index.rst", "install.md"], "titles": ["About", "API", "constants", "data", "constants", "dataconverter", "dataloader", "dataset", "dataset", "parquet", "parquet_dataset", "sqlite", "sqlite_dataset", "sqlite_dataset_perturbed", "extractors", "i3extractor", "i3featureextractor", "i3genericextractor", "i3hybridrecoextractor", "i3ntmuonlabelsextractor", "i3particleextractor", "i3pisaextractor", "i3quesoextractor", "i3retroextractor", "i3splinempeextractor", "i3truthextractor", "i3tumextractor", "utilities", "collections", "frames", "types", "parquet", "parquet_dataconverter", "pipeline", "sqlite", "sqlite_dataconverter", "sqlite_utilities", "utilities", "parquet_to_sqlite", "random", "string_selection_resolver", "deployment", "i3modules", "deployer", "graphnet_module", "models", "coarsening", "components", "layers", "pool", "detector", "detector", "icecube", "prometheus", "gnn", "convnet", "dynedge", "dynedge_jinst", "dynedge_kaggle_tito", "gnn", "graphs", "edges", "edges", "graph_definition", "graphs", "nodes", "nodes", "model", "standard_model", "task", "classification", "reconstruction", "task", "utils", "pisa", "fitting", "plotting", "training", "callbacks", "labels", "loss_functions", "utils", "weight_fitting", "utilities", "argparse", "config", "base_config", "configurable", "dataset_config", "model_config", "parsing", "training_config", "decorators", "filesys", "imports", "logging", "maths", "src", "Contribute", "About", "Install"], "terms": {"graphnet": [0, 1, 2, 3, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35, 36, 37, 38, 39, 40, 41, 75, 76, 77, 82, 83, 84, 93, 94, 95, 98, 99, 100], "i": [0, 1, 15, 17, 28, 29, 30, 35, 36, 39, 40, 76, 82, 84, 93, 94, 95, 98, 99, 100], "an": [0, 5, 30, 32, 35, 40, 93, 95, 98, 99, 100], "open": [0, 98, 99], "sourc": [0, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95, 98, 99], "python": [0, 1, 5, 14, 15, 17, 28, 30, 98, 99, 100], "framework": [0, 99], "aim": [0, 1, 98, 99], "provid": [0, 1, 98, 99, 100], "high": [0, 99], "qualiti": [0, 99], "user": [0, 99, 100], "friendli": [0, 99], "end": [0, 1, 5, 32, 35, 99], "function": [0, 5, 30, 36, 39, 75, 76, 83, 93, 94, 99], "perform": [0, 99], "reconstruct": [0, 1, 16, 18, 19, 23, 24, 26, 41, 45, 69, 99], "task": [0, 1, 45, 98, 99], "neutrino": [0, 1, 75, 99], "telescop": [0, 1, 99], "us": [0, 1, 2, 4, 5, 15, 20, 25, 27, 28, 32, 35, 36, 37, 38, 40, 41, 75, 82, 83, 84, 94, 95, 98, 99, 100], "graph": [0, 1, 45, 98, 99], "neural": [0, 1, 99], "network": [0, 1, 99], "gnn": [0, 1, 45, 99, 100], "make": [0, 5, 82, 98, 99, 100], "fast": [0, 99, 100], "easi": [0, 99], "train": [0, 1, 40, 41, 82, 84, 97, 99, 100], "complex": [0, 99], "model": [0, 1, 41, 76, 77, 84, 97, 99, 100], "can": [0, 1, 15, 17, 20, 38, 75, 76, 82, 84, 98, 99, 100], "event": [0, 1, 22, 36, 38, 40, 75, 82, 99], "state": [0, 99], "art": [0, 99], "arbitrari": [0, 99], "detector": [0, 1, 25, 45, 99], "configur": [0, 1, 75, 83, 85, 95, 99], "infer": [0, 1, 41, 99, 100], "time": [0, 4, 36, 95, 99, 100], "ar": [0, 1, 4, 5, 17, 30, 32, 35, 38, 40, 75, 82, 98, 99, 100], "order": [0, 28, 99], "magnitud": [0, 99], "faster": [0, 99], "than": [0, 95, 99], "tradit": [0, 99], "techniqu": [0, 99], "common": [0, 1, 92, 94, 99], "ml": [0, 1, 99], "develop": [0, 1, 98, 99, 100], "physicist": [0, 1, 99], "wish": [0, 98, 99], "tool": [0, 1, 99], "research": [0, 99], "By": [0, 38, 99], "unit": [0, 5, 94, 98, 99], "both": [0, 17, 76, 99], "group": [0, 5, 32, 35, 99], "increas": [0, 99], "longev": [0, 99], "usabl": [0, 99], "individu": [0, 5, 99], "code": [0, 25, 36, 99], "contribut": [0, 99, 100], "from": [0, 1, 14, 15, 17, 19, 20, 22, 28, 29, 30, 35, 38, 76, 95, 98, 99, 100], "build": [0, 1, 99], "gener": [0, 5, 17, 99], "reusabl": [0, 99], "softwar": [0, 99], "packag": [0, 1, 39, 93, 94, 98, 99, 100], "base": [0, 4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 75, 82, 84, 94, 95, 99], "engin": [0, 99], "best": [0, 98, 99], "practic": [0, 98, 99], "lower": [0, 76, 99], "technic": [0, 99], "threshold": [0, 99], "most": [0, 1, 40, 99, 100], "scientif": [0, 1, 99], "problem": [0, 98, 99], "The": [0, 5, 28, 30, 35, 36, 75, 76, 99], "improv": [0, 1, 84, 99], "classif": [0, 1, 45, 69, 99], "yield": [0, 75, 99], "veri": [0, 40, 99], "accur": [0, 99], "e": [0, 1, 5, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 28, 30, 32, 35, 36, 40, 82, 95, 98, 99, 100], "g": [0, 1, 5, 25, 28, 30, 32, 35, 36, 40, 82, 95, 98, 99, 100], "low": [0, 99], "energi": [0, 4, 82, 99], "observ": [0, 99], "icecub": [0, 1, 16, 29, 30, 45, 50, 94, 99, 100], "here": [0, 98, 99, 100], "implement": [0, 1, 5, 15, 31, 32, 34, 35, 98, 99], "wa": [0, 99], "appli": [0, 15, 99], "oscil": [0, 74, 99], "lead": [0, 99], "signific": [0, 99], "angular": [0, 99], "rang": [0, 99], "relev": [0, 1, 30, 39, 93, 98, 99], "studi": [0, 99], "furthermor": [0, 99], "shown": [0, 99], "could": [0, 98, 99], "muon": [0, 19, 99], "v": [0, 99], "therebi": [0, 1, 99], "effici": [0, 99], "puriti": [0, 99], "sampl": [0, 40, 99], "analysi": [0, 99, 100], "similarli": [0, 30, 99], "ha": [0, 5, 30, 32, 35, 36, 93, 99, 100], "great": [0, 99], "point": [0, 24, 99], "analys": [0, 41, 74, 99], "final": [0, 99], "millisecond": [0, 99], "allow": [0, 41, 99, 100], "whole": [0, 99], "new": [0, 1, 35, 98, 99], "type": [0, 5, 14, 15, 27, 28, 29, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95, 98, 99], "cosmic": [0, 99], "alert": [0, 99], "which": [0, 15, 16, 25, 29, 40, 75, 84, 99, 100], "were": [0, 99], "previous": [0, 99], "unfeas": [0, 99], "possibl": [0, 28, 98, 99], "identifi": [0, 5, 25, 99], "10": [0, 84, 99], "tev": [0, 99], "monitor": [0, 99], "rate": [0, 99], "direct": [0, 99], "real": [0, 99], "thi": [0, 3, 5, 15, 17, 30, 32, 35, 36, 39, 75, 76, 82, 95, 98, 99, 100], "enabl": [0, 3, 99], "first": [0, 98, 99], "ever": [0, 99], "despit": [0, 99], "larg": [0, 99], "background": [0, 99], "origin": [0, 75, 99], "compris": [0, 99], "number": [0, 5, 32, 35, 40, 84, 99], "modul": [0, 3, 30, 41, 74, 77, 83, 94, 99], "necessari": [0, 28, 98, 99], "workflow": [0, 99], "ingest": [0, 1, 3, 99], "raw": [0, 99], "data": [0, 1, 4, 5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 37, 38, 39, 40, 84, 94, 97, 99, 100], "domain": [0, 1, 3, 41, 99], "specif": [0, 1, 3, 5, 16, 30, 31, 32, 34, 35, 36, 41, 98, 99, 100], "format": [0, 1, 3, 5, 28, 32, 35, 76, 98, 99, 100], "deploi": [0, 1, 41, 99], "chain": [0, 1, 41, 99, 100], "illustr": [0, 98, 99], "figur": [0, 76, 99], "level": [0, 25, 36, 95, 99, 100], "overview": [0, 99], "typic": [0, 28, 99], "convert": [0, 1, 3, 5, 28, 32, 35, 38, 99, 100], "industri": [0, 3, 99], "standard": [0, 3, 4, 5, 32, 35, 40, 84, 98, 99], "intermedi": [0, 1, 3, 5, 32, 35, 99, 100], "file": [0, 1, 3, 5, 15, 28, 32, 35, 38, 39, 75, 84, 93, 95, 99, 100], "read": [0, 3, 28, 99, 100], "simpl": [0, 99], "physic": [0, 1, 15, 29, 30, 41, 99], "orient": [0, 99], "compon": [0, 1, 45, 99], "manag": [0, 15, 77, 99], "experi": [0, 1, 77, 99], "log": [0, 1, 77, 83, 99, 100], "deploy": [0, 1, 42, 97, 99], "modular": [0, 99], "subclass": [0, 99], "torch": [0, 94, 99, 100], "nn": [0, 99], "mean": [0, 5, 32, 35, 99], "onli": [0, 1, 75, 82, 94, 99, 100], "need": [0, 28, 99, 100], "import": [0, 1, 36, 83, 99], "few": [0, 98, 99], "exist": [0, 35, 36, 99], "purpos": [0, 99], "built": [0, 99], "them": [0, 1, 28, 75, 99, 100], "togeth": [0, 99], "form": [0, 99], "complet": [0, 99], "extend": [0, 1, 99], "suit": [0, 99], "through": [0, 99], "layer": [0, 45, 47, 99], "connect": [0, 99], "etc": [0, 95, 99], "optimis": [0, 1, 99], "differ": [0, 15, 98, 99, 100], "track": [0, 15, 19, 98, 99], "These": [0, 98, 99], "prepar": [0, 99], "satisfi": [0, 99], "o": [0, 99], "load": [0, 39, 99], "requir": [0, 21, 36, 99, 100], "when": [0, 5, 28, 32, 35, 36, 95, 98, 99, 100], "batch": [0, 84, 99], "do": [0, 98, 99, 100], "predict": [0, 20, 24, 26, 99], "either": [0, 99, 100], "contain": [0, 5, 28, 29, 32, 35, 82, 84, 99, 100], "imag": [0, 1, 98, 99, 100], "portabl": [0, 99], "depend": [0, 99, 100], "free": [0, 99], "split": [0, 99], "up": [0, 5, 32, 35, 98, 99, 100], "interfac": [0, 74, 99, 100], "block": [0, 1, 99], "pre": [0, 98, 99], "directli": [0, 15, 99], "while": [0, 17, 99], "continu": [0, 99], "expand": [0, 99], "": [0, 5, 15, 28, 35, 38, 82, 84, 95, 99, 100], "capabl": [0, 99], "project": [0, 98, 99], "receiv": [0, 99], "fund": [0, 99], "european": [0, 99], "union": [0, 17, 28, 30, 93, 99], "horizon": [0, 99], "2020": [0, 99], "innov": [0, 99], "programm": [0, 99], "under": [0, 99], "mari": [0, 99], "sk\u0142odowska": [0, 99], "curi": [0, 99], "grant": [0, 99], "agreement": [0, 98, 99], "No": [0, 99], "890778": [0, 99], "work": [0, 4, 29, 98, 99, 100], "rasmu": [0, 99], "\u00f8rs\u00f8e": [0, 99], "partli": [0, 99], "punch4nfdi": [0, 99], "consortium": [0, 99], "support": [0, 30, 98, 99, 100], "dfg": [0, 99], "nfdi": [0, 99], "39": [0, 99, 100], "1": [0, 5, 28, 32, 35, 40, 82, 99, 100], "germani": [0, 99], "conveni": [1, 98, 100], "collabor": 1, "solv": [1, 98], "It": [1, 28, 36, 98], "leverag": 1, "advanc": 1, "machin": [1, 100], "learn": [1, 100], "without": [1, 75, 100], "have": [1, 5, 17, 32, 35, 36, 40, 98, 100], "expert": 1, "themselv": 1, "acceler": 1, "area": 1, "phyic": 1, "design": 1, "principl": 1, "all": [1, 5, 15, 17, 32, 35, 36, 95, 98, 100], "streamlin": 1, "process": [1, 5, 15, 98, 100], "transform": [1, 82], "extens": [1, 93], "basic": 1, "across": [1, 2, 30, 37, 83, 84, 95], "variou": 1, "easili": 1, "architectur": 1, "main": [1, 98, 100], "featur": [1, 3, 4, 5, 16, 98], "i3": [1, 5, 15, 29, 30, 32, 35, 39, 93, 100], "more": [1, 36, 39, 95], "index": [1, 5, 30, 36], "sqlite": [1, 3, 7, 35, 36, 38, 100], "suitabl": 1, "plug": 1, "plai": 1, "abstract": [1, 5], "awai": 1, "detail": [1, 100], "expos": 1, "physicst": 1, "what": [1, 98], "i3modul": [1, 41], "includ": [1, 75, 98], "docker": 1, "run": [1, 38], "containeris": 1, "fashion": 1, "subpackag": [1, 3, 7, 14, 41, 45, 60, 83], "dataset": [1, 3, 19, 40, 84], "extractor": [1, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35], "parquet": [1, 3, 7, 32, 38, 100], "util": [1, 3, 14, 28, 29, 30, 36, 38, 39, 40, 45, 77, 84, 93, 94, 95, 97], "constant": [1, 3, 97], "dataconvert": [1, 3, 32, 35], "dataload": [1, 3], "pipelin": [1, 3], "coarsen": [1, 45], "standard_model": [1, 45], "pisa": [1, 21, 75, 76, 94, 97, 100], "fit": [1, 74, 76, 82], "plot": [1, 74], "callback": [1, 77], "label": [1, 19, 22, 76, 77], "loss_funct": [1, 77], "weight_fit": [1, 77], "config": [1, 40, 75, 83, 84], "argpars": [1, 83], "decor": [1, 5, 83, 94], "filesi": [1, 83], "math": [1, 83], "submodul": [1, 3, 7, 9, 11, 14, 27, 31, 34, 37, 42, 45, 47, 50, 54, 60, 61, 65, 69, 74, 77, 83, 85], "global": [2, 4], "i3extractor": [3, 5, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35], "i3featureextractor": [3, 4, 14, 35], "i3genericextractor": [3, 14, 35], "i3hybridrecoextractor": [3, 14], "i3ntmuonlabelsextractor": [3, 14], "i3particleextractor": [3, 14], "i3pisaextractor": [3, 14], "i3quesoextractor": [3, 14], "i3retroextractor": [3, 14], "i3splinempeextractor": [3, 14], "i3truthextractor": [3, 4, 14], "i3tumextractor": [3, 14], "parquet_dataconvert": [3, 31], "sqlite_dataconvert": [3, 34], "sqlite_util": [3, 34], "parquet_to_sqlit": [3, 37], "random": [3, 37, 40], "string_selection_resolv": [3, 37], "truth": [3, 4, 16, 25, 36, 82], "fileset": [3, 5], "init_global_index": [3, 5], "cache_output_fil": [3, 5], "class": [4, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 34, 35, 38, 40, 75, 82, 84, 95, 98], "object": [4, 5, 15, 17, 28, 30, 75, 84, 95], "namespac": 4, "name": [4, 5, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 30, 32, 35, 36, 38, 75, 82, 84, 95, 98, 100], "icecube86": 4, "dom_x": 4, "dom_i": 4, "dom_z": 4, "dom_tim": 4, "charg": 4, "rde": 4, "pmt_area": 4, "deepcor": [4, 16], "upgrad": [4, 16, 100], "string": [4, 5, 28, 32, 35, 40], "pmt_number": 4, "dom_numb": 4, "pmt_dir_x": 4, "pmt_dir_i": 4, "pmt_dir_z": 4, "dom_typ": 4, "prometheu": [4, 45, 50], "sensor_pos_x": 4, "sensor_pos_i": 4, "sensor_pos_z": 4, "t": [4, 30, 36, 76, 100], "kaggl": 4, "x": [4, 5, 25, 32, 35, 76, 82], "y": [4, 25, 76, 100], "z": [4, 5, 25, 32, 35, 100], "auxiliari": 4, "energy_track": 4, "position_x": 4, "position_i": 4, "position_z": 4, "azimuth": 4, "zenith": 4, "pid": [4, 40], "elast": 4, "sim_typ": 4, "interaction_typ": 4, "interaction_tim": 4, "inelast": 4, "stopped_muon": 4, "injection_energi": 4, "injection_typ": 4, "injection_interaction_typ": 4, "injection_zenith": 4, "injection_azimuth": 4, "injection_bjorkenx": 4, "injection_bjorkeni": 4, "injection_position_x": 4, "injection_position_i": 4, "injection_position_z": 4, "injection_column_depth": 4, "primary_lepton_1_typ": 4, "primary_hadron_1_typ": 4, "primary_lepton_1_position_x": 4, "primary_lepton_1_position_i": 4, "primary_lepton_1_position_z": 4, "primary_hadron_1_position_x": 4, "primary_hadron_1_position_i": 4, "primary_hadron_1_position_z": 4, "primary_lepton_1_direction_theta": 4, "primary_lepton_1_direction_phi": 4, "primary_hadron_1_direction_theta": 4, "primary_hadron_1_direction_phi": 4, "primary_lepton_1_energi": 4, "primary_hadron_1_energi": 4, "total_energi": 4, "i3_fil": [5, 15], "str": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 82, 84, 93, 95], "gcd_file": [5, 15], "paramet": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95], "output_fil": [5, 32, 35], "global_index": 5, "avail": [5, 17, 94], "pool": [5, 45, 47], "worker": [5, 32, 35, 39, 84, 95], "return": [5, 15, 28, 29, 30, 32, 35, 36, 38, 39, 40, 75, 76, 82, 84, 93, 94, 95], "none": [5, 15, 17, 25, 29, 30, 32, 35, 36, 38, 40, 75, 82, 84, 93, 95], "synchron": 5, "list": [5, 15, 17, 25, 28, 30, 32, 35, 36, 38, 39, 40, 76, 82, 93, 95], "process_method": 5, "cach": 5, "output": [5, 32, 35, 38, 75, 82, 100], "typevar": 5, "f": 5, "bound": [5, 76], "callabl": [5, 30, 82, 94], "ani": [5, 28, 29, 30, 32, 35, 76, 82, 84, 95, 100], "outdir": [5, 32, 35, 38, 75], "gcd_rescu": [5, 32, 35, 93], "nb_files_to_batch": [5, 32, 35], "sequential_batch_pattern": [5, 32, 35], "input_file_batch_pattern": [5, 32, 35], "index_column": [5, 32, 35, 36, 40, 75, 82], "icetray_verbos": [5, 32, 35], "abc": [5, 15, 82], "logger": [5, 15, 38, 40, 82, 83, 95, 100], "construct": [5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 75, 82, 84, 95], "regular": [5, 30, 32, 35], "express": [5, 32, 35], "accord": [5, 32, 35], "match": [5, 32, 35, 82, 93], "certain": [5, 32, 35, 38, 75], "pattern": [5, 32, 35], "wildcard": [5, 32, 35], "same": [5, 30, 32, 35, 36, 95], "input": [5, 32, 35], "replac": [5, 32, 35], "period": [5, 32, 35], "special": [5, 17, 32, 35], "interpret": [5, 32, 35], "liter": [5, 32, 35], "charact": [5, 32, 35], "regex": [5, 32, 35], "For": [5, 30, 32, 35], "instanc": [5, 15, 25, 30, 32, 35, 75, 100], "A": [5, 32, 35, 75, 82, 100], "_": [5, 32, 35], "0": [5, 32, 35, 40, 75, 76], "9": [5, 32, 35], "5": [5, 32, 35, 40, 84, 100], "zst": [5, 32, 35], "find": [5, 32, 35, 93], "whose": [5, 32, 35], "one": [5, 32, 35, 36, 93, 98, 100], "capit": [5, 32, 35], "letter": [5, 32, 35], "follow": [5, 32, 35, 82, 98, 100], "underscor": [5, 32, 35], "five": [5, 32, 35], "upgrade_genie_step4_141020_a_000000": [5, 32, 35], "upgrade_genie_step4_141020_a_000001": [5, 32, 35], "upgrade_genie_step4_141020_a_000008": [5, 32, 35], "upgrade_genie_step4_141020_a_000009": [5, 32, 35], "would": [5, 32, 35, 98], "upgrade_genie_step4_141020_a_00000x": [5, 32, 35], "suffix": [5, 32, 35], "upgrade_genie_step4_141020_a_000010": [5, 32, 35], "separ": [5, 28, 32, 35, 100], "upgrade_genie_step4_141020_a_00001x": [5, 32, 35], "int": [5, 19, 22, 32, 35, 40, 75, 82, 84, 95], "properti": [5, 15, 20, 30, 95], "file_suffix": [5, 32, 35], "execut": [5, 36], "method": [5, 15, 27, 28, 29, 30, 32, 35, 82], "set": [5, 17, 98], "inherit": [5, 15, 30, 95], "path": [5, 36, 39, 75, 76, 84, 93, 100], "correspond": [5, 28, 30, 35, 39, 82, 93, 100], "gcd": [5, 15, 29, 39, 93], "save_data": [5, 32, 35], "save": [5, 15, 28, 32, 35, 36, 75, 82, 100], "ordereddict": [5, 32, 35], "extract": [5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 35, 38, 39], "merge_fil": [5, 32, 35], "input_fil": [5, 32, 35], "merg": [5, 32, 35, 100], "result": [5, 32, 35, 100], "option": [5, 25, 32, 35, 75, 76, 82, 83, 84, 93, 100], "default": [5, 17, 25, 28, 32, 35, 36, 38, 75, 76, 82, 84, 93], "current": [5, 32, 35, 40, 98, 100], "rais": [5, 17, 32], "notimplementederror": [5, 32], "If": [5, 17, 32, 35, 75, 82, 98, 100], "been": [5, 32, 98], "backend": [5, 32, 35], "question": 5, "get_map_funct": 5, "nb_file": 5, "map": [5, 16, 17, 35, 36], "pure": [5, 14, 15, 17, 30], "multiprocess": [5, 100], "tupl": [5, 29, 30, 75, 76, 84], "parquet_dataset": [7, 9], "sqlite_dataset": [7, 11], "sqlite_dataset_perturb": [7, 11], "collect": [14, 15, 27], "i3fram": [14, 15, 17, 29, 30], "frame": [14, 15, 17, 27, 30, 35], "i3extractorcollect": [14, 15], "i3featureextractoricecube86": [14, 16], "i3featureextractoricecubedeepcor": [14, 16], "i3featureextractoricecubeupgrad": [14, 16], "i3pulsenoisetruthflagicecubeupgrad": [14, 16], "i3galacticplanehybridrecoextractor": [14, 18], "i3ntmuonlabelextractor": [14, 19], "i3splinempeicextractor": [14, 24], "inform": [15, 17, 25, 76], "should": [15, 28, 40, 98, 100], "__call__": 15, "icetrai": [15, 29, 30, 94], "keep": 15, "proven": 15, "tabl": [15, 35, 36, 75, 82], "set_fil": 15, "store": [15, 36, 75], "refer": 15, "being": 15, "get": [15, 29, 100], "multipl": [15, 95], "treat": 15, "singl": 15, "pulsemap": [16, 35], "puls": [16, 17, 29, 30, 35, 36], "seri": [16, 17, 29, 30, 36], "86": 16, "nois": [16, 29], "flag": 16, "ad": [16, 75], "kei": [17, 28, 29, 30, 35, 36], "exclude_kei": 17, "dynam": 17, "pars": [17, 76, 83, 84, 85], "call": [17, 30, 35, 75, 82, 95], "tri": [17, 30], "automat": [17, 98], "cast": [17, 30], "done": [17, 95, 98], "recurs": [17, 30, 93], "each": [17, 28, 30, 36, 38, 39, 75, 76, 93], "look": [17, 100], "member": [17, 30, 95], "variabl": [17, 30, 82, 95], "signatur": [17, 30], "similar": [17, 30, 100], "dict": [17, 28, 30, 35, 75, 76, 84], "handl": [17, 84, 95], "hand": 17, "case": [17, 100], "per": [17, 36, 82], "mc": [17, 35, 36], "tree": [17, 35], "trigger": 17, "exclud": [17, 38, 100], "valueerror": 17, "hybrid": 18, "galatict": 18, "plane": 18, "tum": [19, 26], "dnn": [19, 26], "padding_valu": [19, 22], "northeren": 19, "i3particl": 20, "other": [20, 36, 98], "algorithm": 20, "comparison": 20, "quantiti": 21, "select": [22, 40, 82, 98], "queso": 22, "retro": 23, "splinemp": 24, "border": 25, "mctree": [25, 29], "ndarrai": [25, 82], "arrai": [25, 28], "boundari": 25, "volum": 25, "coordin": 25, "particl": [25, 36], "start": [25, 98, 100], "stop": [25, 84], "within": 25, "hard": 25, "i3mctre": 25, "valu": [25, 28, 35, 36, 76, 84], "flatten_nested_dictionari": [27, 28], "serialis": [27, 28], "transpose_list_of_dict": [27, 28], "frame_is_montecarlo": [27, 29], "frame_is_nois": [27, 29], "get_om_keys_and_pulseseri": [27, 29], "is_boost_enum": [27, 30], "is_boost_class": [27, 30], "is_icecube_class": [27, 30], "is_typ": [27, 30], "is_method": [27, 30], "break_cyclic_recurs": [27, 30], "get_member_vari": [27, 30], "cast_object_to_pure_python": [27, 30], "cast_pulse_series_to_pure_python": [27, 30], "manipul": 28, "obj": [28, 30], "parent_kei": 28, "flatten": 28, "nest": 28, "dictionari": [28, 29, 30, 35, 75, 76], "non": [28, 30, 35, 36], "exampl": [28, 40, 100], "d": [28, 98], "b": 28, "c": [28, 100], "2": [28, 75, 76, 100], "a__b": 28, "applic": 28, "combin": 28, "parent": 28, "__": [28, 30], "concaten": 28, "nester": 28, "json": 28, "therefor": 28, "we": [28, 30, 40, 98, 100], "element": [28, 30], "outer": 28, "abl": [28, 100], "de": 28, "transpos": 28, "check": [29, 30, 35, 36, 84, 93, 94, 98, 100], "whether": [29, 30, 35, 36, 93, 94], "mont": 29, "carlo": 29, "simul": 29, "bool": [29, 30, 35, 36, 40, 75, 82, 84, 93, 94, 95], "pulseseri": 29, "calibr": [29, 30], "indici": [29, 40], "gcd_dict": [29, 30], "p": [29, 35], "om": [29, 30], "dataclass": 29, "i3calibr": 29, "indicesfor": 29, "boost": 30, "enum": 30, "fn": 30, "ensur": [30, 39, 95, 98, 100], "isn": 30, "return_discard": 30, "valid": [30, 40, 84], "ignor": 30, "mangl": 30, "take": [30, 35, 98], "mainli": 30, "cannot": 30, "trivial": 30, "doe": 30, "try": 30, "length": 30, "equival": 30, "its": 30, "like": [30, 98], "otherwis": 30, "itself": 30, "deem": 30, "wai": [30, 40, 98, 100], "represent": 30, "optic": 30, "found": 30, "parquetdataconvert": [31, 32], "sqlitedataconvert": [34, 35, 100], "construct_datafram": [34, 35], "is_pulse_map": [34, 35], "is_mc_tre": [34, 35], "database_exist": [34, 36], "database_table_exist": [34, 36], "run_sql_cod": [34, 36], "save_to_sql": [34, 36], "attach_index": [34, 36], "create_t": [34, 36], "create_table_and_save_to_sql": [34, 36], "db": 35, "databas": [35, 36, 38, 75, 82, 100], "max_table_s": 35, "maximum": [35, 84], "row": [35, 36], "given": [35, 82, 84], "exce": 35, "limit": 35, "creat": [35, 36, 98, 100], "any_pulsemap_is_non_empti": 35, "data_dict": 35, "empti": 35, "retriev": 35, "splitinicepuls": 35, "least": [35, 98, 100], "true": [35, 36, 75, 82], "becaus": [35, 39], "instead": 35, "alwai": 35, "panda": [35, 40, 82], "datafram": [35, 36, 40, 75, 82], "table_nam": [35, 36], "database_path": [36, 75, 82], "df": 36, "must": [36, 82, 98], "alreadi": [36, 100], "attach": 36, "queri": [36, 40], "column": [36, 75, 82], "default_typ": 36, "null": 36, "integer_primary_kei": 36, "event_no": [36, 40, 82], "NOT": 36, "integ": 36, "primari": 36, "Such": 36, "uniqu": [36, 38], "appropri": 36, "expect": [36, 40], "doesn": 36, "parquettosqliteconvert": [37, 38], "pairwise_shuffl": [37, 39], "stringselectionresolv": [37, 40], "parquet_path": 38, "mc_truth_tabl": 38, "excluded_field": 38, "assign": [38, 98], "id": 38, "everi": [38, 100], "field": [38, 76], "One": [38, 76], "choos": 38, "argument": [38, 82, 84], "exclude_field": 38, "database_nam": 38, "convers": [38, 100], "directori": [38, 75, 93], "rng": 39, "relat": [39, 93], "i3_list": [39, 93], "gcd_list": [39, 93], "shuffl": 39, "correpond": 39, "handi": 39, "even": 39, "files_list": 39, "gcd_shuffl": 39, "i3_shuffl": 39, "resolv": 40, "indic": [40, 84, 98], "seed": 40, "use_cach": 40, "datasetconfig": 40, "flexibl": 40, "defin": 40, "below": [40, 76, 82, 98, 100], "show": 40, "involv": 40, "cover": 40, "yml": [40, 84], "test": [40, 94, 98], "50000": 40, "ab": 40, "12": 40, "14": 40, "16": 40, "13": [40, 100], "10000": 40, "compat": 40, "syntax": 40, "mai": [40, 100], "also": 40, "specifi": [40, 76, 100], "fix": 40, "randomli": 40, "20": [40, 95], "graphnet_modul": [41, 42], "convnet": [45, 54], "dynedg": [45, 54], "dynedge_jinst": [45, 54], "dynedge_kaggle_tito": [45, 54], "edg": [45, 60], "node": [45, 60], "graph_definit": [45, 60], "config_updat": [74, 75], "weightfitt": [74, 75, 77, 82], "contourfitt": [74, 75], "read_entri": [74, 76], "plot_2d_contour": [74, 76], "plot_1d_contour": [74, 76], "contour": [75, 76], "config_path": 75, "new_config_path": 75, "dummy_sect": 75, "updat": 75, "temp": 75, "dummi": 75, "section": 75, "header": 75, "configupdat": 75, "programat": 75, "truth_tabl": [75, 82], "statistical_fit": 75, "weight": [75, 82, 100], "fit_weight": [75, 82], "config_outdir": 75, "weight_nam": [75, 82], "pisa_config_dict": 75, "add_to_databas": [75, 82], "flux": 75, "self": 75, "_database_path": 75, "statist": 75, "effect": [75, 98], "account": 75, "systemat": 75, "hypersurfac": 75, "chang": [75, 98], "assumpt": 75, "regard": 75, "fals": [75, 82], "two": 75, "pipeline_path": 75, "post_fix": 75, "model_nam": 75, "include_retro": 75, "fit_1d_contour": 75, "run_nam": 75, "config_dict": 75, "grid_siz": 75, "n_worker": 75, "theta23_minmax": 75, "36": 75, "54": 75, "dm31_minmax": 75, "3": [75, 76, 98, 100], "7": 75, "1d": [75, 76], "float": [75, 76], "fit_2d_contour": 75, "2d": [75, 76], "entri": [76, 84], "content": 76, "contour_data": 76, "xlim": 76, "4": 76, "6": 76, "ylim": 76, "0023799999999999997": 76, "0025499999999999997": 76, "chi2_critical_valu": 76, "width": 76, "height": 76, "path_to_pisa_fit_result": 76, "name_of_my_model_in_fit": 76, "legend": 76, "color": 76, "linestyl": 76, "style": [76, 98], "line": [76, 84], "upper": 76, "axi": 76, "605": 76, "critic": [76, 95], "chi2": 76, "90": 76, "cl": 76, "note": 76, "right": 76, "176": 76, "inch": 76, "388": 76, "706": 76, "abov": [76, 82, 100], "352": 76, "uniform": [77, 82], "bjoernlow": [77, 82], "produc": 82, "public": 82, "uniformweightfitt": 82, "bin": 82, "kwarg": [82, 95], "privat": 82, "_fit_weight": 82, "sql": 82, "desir": [82, 93], "space": 82, "np": 82, "log10": 82, "happen": 82, "addit": 82, "pass": [82, 98], "distribut": 82, "x_low": 82, "wherea": 82, "curv": 82, "base_config": [83, 85], "dataset_config": [83, 85], "model_config": [83, 85], "training_config": [83, 85], "argumentpars": [83, 84], "is_gcd_fil": [83, 93], "is_i3_fil": [83, 93], "has_extens": [83, 93], "find_i3_fil": [83, 93], "has_icecube_packag": [83, 94], "has_torch_packag": [83, 94], "has_pisa_packag": [83, 94], "requires_icecub": [83, 94], "repeatfilt": [83, 95], "consist": [84, 95, 98], "cli": 84, "present": [84, 93, 94], "pop_default": 84, "remov": 84, "usag": 84, "descript": 84, "command": [84, 100], "standard_argu": 84, "size": 84, "128": 84, "help": [84, 98], "home": [84, 100], "runner": 84, "local": 84, "lib": [84, 100], "python3": 84, "training_example_data_sqlit": 84, "earli": 84, "patienc": 84, "epoch": 84, "loss": 84, "after": 84, "gpu": [84, 100], "narg": 84, "max": 84, "50": 84, "example_energy_reconstruction_model": 84, "num": 84, "fetch": 84, "with_standard_argu": 84, "arg": [84, 95], "add": [84, 98, 100], "overwritten": 84, "system": [93, 100], "filenam": 93, "dir": 93, "search": 93, "test_funct": 94, "filter": 95, "out": [95, 98, 100], "repeat": 95, "messag": 95, "nb_repeats_allow": 95, "record": 95, "print": 95, "logrecord": 95, "class_nam": 95, "log_fold": 95, "clear": 95, "intuit": 95, "composit": 95, "rather": 95, "loggeradapt": 95, "chosen": 95, "avoid": [95, 98], "clash": 95, "pytorch_lightn": 95, "lightningmodul": 95, "setlevel": 95, "deleg": 95, "msg": 95, "error": [95, 98], "warn": 95, "info": [95, 100], "debug": 95, "warning_onc": 95, "exactli": 95, "onc": 95, "handler": 95, "file_handl": 95, "filehandl": 95, "stream_handl": 95, "streamhandl": 95, "api": 97, "To": [98, 100], "sure": [98, 100], "smooth": 98, "guidelin": 98, "guid": 98, "encourag": 98, "contributor": 98, "discuss": 98, "bug": 98, "anyth": 98, "you": [98, 100], "place": 98, "describ": 98, "altern": 98, "yourself": 98, "ownership": 98, "particular": 98, "activ": [98, 100], "transpar": 98, "prioriti": 98, "situat": 98, "lot": 98, "effort": 98, "go": 98, "turn": 98, "outsid": 98, "scope": 98, "solut": 98, "better": 98, "fork": 98, "repo": 98, "dedic": 98, "branch": [98, 100], "your": [98, 100], "repositori": 98, "graphdefinit": 98, "euclidean": 98, "definit": 98, "own": [98, 100], "team": 98, "accept": 98, "autom": 98, "review": 98, "pep8": 98, "docstr": 98, "googl": 98, "hint": 98, "clean": [98, 100], "see": [98, 100], "version": [98, 100], "8": [98, 100], "adher": 98, "pep": 98, "pylint": 98, "flake8": 98, "black": 98, "well": 98, "recommend": [98, 100], "mypi": 98, "pydocstyl": 98, "docformatt": 98, "commit": 98, "hook": 98, "instal": 98, "come": 98, "tag": [98, 100], "pip": [98, 100], "Then": 98, "everytim": 98, "pep257": 98, "static": 98, "concept": 98, "http": 98, "ljvmiranda921": 98, "io": 98, "notebook": 98, "2018": 98, "06": 98, "21": 98, "precommit": 98, "environ": 100, "virtual": 100, "anaconda": 100, "prove": 100, "instruct": 100, "setup": 100, "want": 100, "part": 100, "In": 100, "runtim": 100, "achiev": 100, "bash": 100, "shell": 100, "eval": 100, "cvmf": 100, "opensciencegrid": 100, "org": 100, "py3": 100, "v4": 100, "sh": 100, "rhel_7_x86_64": 100, "metaproject": 100, "v1": 100, "env": 100, "alia": 100, "script": 100, "With": 100, "now": 100, "light": 100, "extra": 100, "pytorch": 100, "geometr": 100, "just": 100, "won": 100, "later": 100, "don": 100, "r": 100, "torch_cpu": 100, "txt": 100, "cpu": 100, "torch_gpu": 100, "prefer": 100, "unix": 100, "git": 100, "clone": 100, "github": 100, "com": 100, "usernam": 100, "cd": 100, "conda": 100, "gcc_linux": 100, "64": 100, "gxx_linux": 100, "libgcc": 100, "cudatoolkit": 100, "11": 100, "forg": 100, "torch_maco": 100, "On": 100, "maco": 100, "box": 100, "compil": 100, "gcc": 100, "date": 100, "possibli": 100, "cuda": 100, "toolkit": 100, "recent": 100, "omit": 100, "newer": 100, "export": 100, "ld_library_path": 100, "anaconda3": 100, "miniconda3": 100, "bashrc": 100, "librari": 100, "access": 100, "so": 100, "re": 100, "intend": 100, "consid": 100, "rm": 100, "asogaard": 100, "latest": 100, "dc423315742c": 100, "01_icetrai": 100, "01_convert_i3_fil": 100, "py": 100, "2023": 100, "01": 100, "24": 100, "41": 100, "27": 100, "__init__": 100, "write": 100, "graphnet_20230124": 100, "134127": 100, "46": 100, "root": 100, "convert_i3_fil": 100, "ic86": 100, "thread": 100, "100": 100, "00": 100, "79": 100, "42": 100, "26": 100, "413": 100, "88it": 100, "specialis": 100, "ones": 100, "push": 100, "vx": 100}, "objects": {"": [[1, 0, 0, "-", "graphnet"]], "graphnet": [[2, 0, 0, "-", "constants"], [3, 0, 0, "-", "data"], [41, 0, 0, "-", "deployment"], [74, 0, 0, "-", "pisa"], [77, 0, 0, "-", "training"], [83, 0, 0, "-", "utilities"]], "graphnet.data": [[4, 0, 0, "-", "constants"], [5, 0, 0, "-", "dataconverter"], [14, 0, 0, "-", "extractors"], [31, 0, 0, "-", "parquet"], [34, 0, 0, "-", "sqlite"], [37, 0, 0, "-", "utilities"]], "graphnet.data.constants": [[4, 1, 1, "", "FEATURES"], [4, 1, 1, "", "TRUTH"]], "graphnet.data.constants.FEATURES": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.constants.TRUTH": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.dataconverter": [[5, 1, 1, "", "DataConverter"], [5, 1, 1, "", "FileSet"], [5, 5, 1, "", "cache_output_files"], [5, 5, 1, "", "init_global_index"]], "graphnet.data.dataconverter.DataConverter": [[5, 3, 1, "", "execute"], [5, 4, 1, "", "file_suffix"], [5, 3, 1, "", "get_map_function"], [5, 3, 1, "", "merge_files"], [5, 3, 1, "", "save_data"]], "graphnet.data.dataconverter.FileSet": [[5, 2, 1, "", "gcd_file"], [5, 2, 1, "", "i3_file"]], "graphnet.data.extractors": [[15, 0, 0, "-", "i3extractor"], [16, 0, 0, "-", "i3featureextractor"], [17, 0, 0, "-", "i3genericextractor"], [18, 0, 0, "-", "i3hybridrecoextractor"], [19, 0, 0, "-", "i3ntmuonlabelsextractor"], [20, 0, 0, "-", "i3particleextractor"], [21, 0, 0, "-", "i3pisaextractor"], [22, 0, 0, "-", "i3quesoextractor"], [23, 0, 0, "-", "i3retroextractor"], [24, 0, 0, "-", "i3splinempeextractor"], [25, 0, 0, "-", "i3truthextractor"], [26, 0, 0, "-", "i3tumextractor"], [27, 0, 0, "-", "utilities"]], "graphnet.data.extractors.i3extractor": [[15, 1, 1, "", "I3Extractor"], [15, 1, 1, "", "I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor.I3Extractor": [[15, 4, 1, "", "name"], [15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3extractor.I3ExtractorCollection": [[15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3featureextractor": [[16, 1, 1, "", "I3FeatureExtractor"], [16, 1, 1, "", "I3FeatureExtractorIceCube86"], [16, 1, 1, "", "I3FeatureExtractorIceCubeDeepCore"], [16, 1, 1, "", "I3FeatureExtractorIceCubeUpgrade"], [16, 1, 1, "", "I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3genericextractor": [[17, 1, 1, "", "I3GenericExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, 1, 1, "", "I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, 1, 1, "", "I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, 1, 1, "", "I3ParticleExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, 1, 1, "", "I3PISAExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, 1, 1, "", "I3QUESOExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, 1, 1, "", "I3RetroExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, 1, 1, "", "I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, 1, 1, "", "I3TruthExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, 1, 1, "", "I3TUMExtractor"]], "graphnet.data.extractors.utilities": [[28, 0, 0, "-", "collections"], [29, 0, 0, "-", "frames"], [30, 0, 0, "-", "types"]], "graphnet.data.extractors.utilities.collections": [[28, 5, 1, "", "flatten_nested_dictionary"], [28, 5, 1, "", "serialise"], [28, 5, 1, "", "transpose_list_of_dicts"]], "graphnet.data.extractors.utilities.frames": [[29, 5, 1, "", "frame_is_montecarlo"], [29, 5, 1, "", "frame_is_noise"], [29, 5, 1, "", "get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.types": [[30, 5, 1, "", "break_cyclic_recursion"], [30, 5, 1, "", "cast_object_to_pure_python"], [30, 5, 1, "", "cast_pulse_series_to_pure_python"], [30, 5, 1, "", "get_member_variables"], [30, 5, 1, "", "is_boost_class"], [30, 5, 1, "", "is_boost_enum"], [30, 5, 1, "", "is_icecube_class"], [30, 5, 1, "", "is_method"], [30, 5, 1, "", "is_type"]], "graphnet.data.parquet": [[32, 0, 0, "-", "parquet_dataconverter"]], "graphnet.data.parquet.parquet_dataconverter": [[32, 1, 1, "", "ParquetDataConverter"]], "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter": [[32, 2, 1, "", "file_suffix"], [32, 3, 1, "", "merge_files"], [32, 3, 1, "", "save_data"]], "graphnet.data.sqlite": [[35, 0, 0, "-", "sqlite_dataconverter"], [36, 0, 0, "-", "sqlite_utilities"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, 1, 1, "", "SQLiteDataConverter"], [35, 5, 1, "", "construct_dataframe"], [35, 5, 1, "", "is_mc_tree"], [35, 5, 1, "", "is_pulse_map"]], "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter": [[35, 3, 1, "", "any_pulsemap_is_non_empty"], [35, 2, 1, "", "file_suffix"], [35, 3, 1, "", "merge_files"], [35, 3, 1, "", "save_data"]], "graphnet.data.sqlite.sqlite_utilities": [[36, 5, 1, "", "attach_index"], [36, 5, 1, "", "create_table"], [36, 5, 1, "", "create_table_and_save_to_sql"], [36, 5, 1, "", "database_exists"], [36, 5, 1, "", "database_table_exists"], [36, 5, 1, "", "run_sql_code"], [36, 5, 1, "", "save_to_sql"]], "graphnet.data.utilities": [[38, 0, 0, "-", "parquet_to_sqlite"], [39, 0, 0, "-", "random"], [40, 0, 0, "-", "string_selection_resolver"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, 1, 1, "", "ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter": [[38, 3, 1, "", "run"]], "graphnet.data.utilities.random": [[39, 5, 1, "", "pairwise_shuffle"]], "graphnet.data.utilities.string_selection_resolver": [[40, 1, 1, "", "StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver": [[40, 3, 1, "", "resolve"]], "graphnet.pisa": [[75, 0, 0, "-", "fitting"], [76, 0, 0, "-", "plotting"]], "graphnet.pisa.fitting": [[75, 1, 1, "", "ContourFitter"], [75, 1, 1, "", "WeightFitter"], [75, 5, 1, "", "config_updater"]], "graphnet.pisa.fitting.ContourFitter": [[75, 3, 1, "", "fit_1d_contour"], [75, 3, 1, "", "fit_2d_contour"]], "graphnet.pisa.fitting.WeightFitter": [[75, 3, 1, "", "fit_weights"]], "graphnet.pisa.plotting": [[76, 5, 1, "", "plot_1D_contour"], [76, 5, 1, "", "plot_2D_contour"], [76, 5, 1, "", "read_entry"]], "graphnet.training": [[82, 0, 0, "-", "weight_fitting"]], "graphnet.training.weight_fitting": [[82, 1, 1, "", "BjoernLow"], [82, 1, 1, "", "Uniform"], [82, 1, 1, "", "WeightFitter"]], "graphnet.training.weight_fitting.WeightFitter": [[82, 3, 1, "", "fit"]], "graphnet.utilities": [[84, 0, 0, "-", "argparse"], [92, 0, 0, "-", "decorators"], [93, 0, 0, "-", "filesys"], [94, 0, 0, "-", "imports"], [95, 0, 0, "-", "logging"]], "graphnet.utilities.argparse": [[84, 1, 1, "", "ArgumentParser"], [84, 1, 1, "", "Options"]], "graphnet.utilities.argparse.ArgumentParser": [[84, 2, 1, "", "standard_arguments"], [84, 3, 1, "", "with_standard_arguments"]], "graphnet.utilities.argparse.Options": [[84, 3, 1, "", "contains"], [84, 3, 1, "", "pop_default"]], "graphnet.utilities.filesys": [[93, 5, 1, "", "find_i3_files"], [93, 5, 1, "", "has_extension"], [93, 5, 1, "", "is_gcd_file"], [93, 5, 1, "", "is_i3_file"]], "graphnet.utilities.imports": [[94, 5, 1, "", "has_icecube_package"], [94, 5, 1, "", "has_pisa_package"], [94, 5, 1, "", "has_torch_package"], [94, 5, 1, "", "requires_icecube"]], "graphnet.utilities.logging": [[95, 1, 1, "", "Logger"], [95, 1, 1, "", "RepeatFilter"]], "graphnet.utilities.logging.Logger": [[95, 3, 1, "", "critical"], [95, 3, 1, "", "debug"], [95, 3, 1, "", "error"], [95, 4, 1, "", "file_handlers"], [95, 4, 1, "", "handlers"], [95, 3, 1, "", "info"], [95, 3, 1, "", "setLevel"], [95, 4, 1, "", "stream_handlers"], [95, 3, 1, "", "warning"], [95, 3, 1, "", "warning_once"]], "graphnet.utilities.logging.RepeatFilter": [[95, 3, 1, "", "filter"], [95, 2, 1, "", "nb_repeats_allowed"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:attribute", "3": "py:method", "4": "py:property", "5": "py:function"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "attribute", "Python attribute"], "3": ["py", "method", "Python method"], "4": ["py", "property", "Python property"], "5": ["py", "function", "Python function"]}, "titleterms": {"about": [0, 99], "impact": [0, 99], "usag": [0, 99], "acknowledg": [0, 99], "api": 1, "constant": [2, 4], "data": 3, "dataconvert": 5, "dataload": 6, "dataset": [7, 8], "parquet": [9, 31], "parquet_dataset": 10, "sqlite": [11, 34], "sqlite_dataset": 12, "sqlite_dataset_perturb": 13, "extractor": 14, "i3extractor": 15, "i3featureextractor": 16, "i3genericextractor": 17, "i3hybridrecoextractor": 18, "i3ntmuonlabelsextractor": 19, "i3particleextractor": 20, "i3pisaextractor": 21, "i3quesoextractor": 22, "i3retroextractor": 23, "i3splinempeextractor": 24, "i3truthextractor": 25, "i3tumextractor": 26, "util": [27, 37, 73, 81, 83], "collect": 28, "frame": 29, "type": 30, "parquet_dataconvert": 32, "pipelin": 33, "sqlite_dataconvert": 35, "sqlite_util": 36, "parquet_to_sqlit": 38, "random": 39, "string_selection_resolv": 40, "deploy": [41, 43], "i3modul": 42, "graphnet_modul": 44, "model": [45, 67], "coarsen": 46, "compon": 47, "layer": 48, "pool": 49, "detector": [50, 51], "icecub": 52, "prometheu": 53, "gnn": [54, 59], "convnet": 55, "dynedg": 56, "dynedge_jinst": 57, "dynedge_kaggle_tito": 58, "graph": [60, 64], "edg": [61, 62], "graph_definit": 63, "node": [65, 66], "standard_model": 68, "task": [69, 72], "classif": 70, "reconstruct": 71, "pisa": 74, "fit": 75, "plot": 76, "train": 77, "callback": 78, "label": 79, "loss_funct": 80, "weight_fit": 82, "argpars": 84, "config": 85, "base_config": 86, "configur": 87, "dataset_config": 88, "model_config": 89, "pars": 90, "training_config": 91, "decor": 92, "filesi": 93, "import": 94, "log": 95, "math": 96, "src": 97, "contribut": 98, "github": 98, "issu": 98, "pull": 98, "request": 98, "convent": 98, "code": 98, "qualiti": 98, "instal": 100, "icetrai": 100, "stand": 100, "alon": 100, "run": 100, "docker": 100}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 60}, "alltitles": {"About": [[0, "about"], [99, "about"]], "Impact": [[0, "impact"], [99, "impact"]], "Usage": [[0, "usage"], [99, "usage"]], "Acknowledgements": [[0, "acknowledgements"], [99, "acknowledgements"]], "API": [[1, "module-graphnet"]], "constants": [[2, "module-graphnet.constants"], [4, "module-graphnet.data.constants"]], "data": [[3, "module-graphnet.data"]], "dataconverter": [[5, "module-graphnet.data.dataconverter"]], "dataloader": [[6, "dataloader"]], "dataset": [[7, "dataset"], [8, "dataset"]], "parquet": [[9, "parquet"], [31, "module-graphnet.data.parquet"]], "parquet_dataset": [[10, "parquet-dataset"]], "sqlite": [[11, "sqlite"], [34, "module-graphnet.data.sqlite"]], "sqlite_dataset": [[12, "sqlite-dataset"]], "sqlite_dataset_perturbed": [[13, "sqlite-dataset-perturbed"]], "extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "utilities": [[27, "module-graphnet.data.extractors.utilities"], [37, "module-graphnet.data.utilities"], [83, "module-graphnet.utilities"]], "collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "types": [[30, "module-graphnet.data.extractors.utilities.types"]], "parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "pipeline": [[33, "pipeline"]], "sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "random": [[39, "module-graphnet.data.utilities.random"]], "string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "deployment": [[41, "module-graphnet.deployment"]], "i3modules": [[42, "i3modules"]], "deployer": [[43, "deployer"]], "graphnet_module": [[44, "graphnet-module"]], "models": [[45, "models"]], "coarsening": [[46, "coarsening"]], "components": [[47, "components"]], "layers": [[48, "layers"]], "pool": [[49, "pool"]], "detector": [[50, "detector"], [51, "detector"]], "icecube": [[52, "icecube"]], "prometheus": [[53, "prometheus"]], "gnn": [[54, "gnn"], [59, "gnn"]], "convnet": [[55, "convnet"]], "dynedge": [[56, "dynedge"]], "dynedge_jinst": [[57, "dynedge-jinst"]], "dynedge_kaggle_tito": [[58, "dynedge-kaggle-tito"]], "graphs": [[60, "graphs"], [64, "graphs"]], "edges": [[61, "edges"], [62, "edges"]], "graph_definition": [[63, "graph-definition"]], "nodes": [[65, "nodes"], [66, "nodes"]], "model": [[67, "model"]], "standard_model": [[68, "standard-model"]], "task": [[69, "task"], [72, "task"]], "classification": [[70, "classification"]], "reconstruction": [[71, "reconstruction"]], "utils": [[73, "utils"], [81, "utils"]], "pisa": [[74, "module-graphnet.pisa"]], "fitting": [[75, "module-graphnet.pisa.fitting"]], "plotting": [[76, "module-graphnet.pisa.plotting"]], "training": [[77, "module-graphnet.training"]], "callbacks": [[78, "callbacks"]], "labels": [[79, "labels"]], "loss_functions": [[80, "loss-functions"]], "weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "argparse": [[84, "module-graphnet.utilities.argparse"]], "config": [[85, "config"]], "base_config": [[86, "base-config"]], "configurable": [[87, "configurable"]], "dataset_config": [[88, "dataset-config"]], "model_config": [[89, "model-config"]], "parsing": [[90, "parsing"]], "training_config": [[91, "training-config"]], "decorators": [[92, "module-graphnet.utilities.decorators"]], "filesys": [[93, "module-graphnet.utilities.filesys"]], "imports": [[94, "module-graphnet.utilities.imports"]], "logging": [[95, "module-graphnet.utilities.logging"]], "maths": [[96, "maths"]], "src": [[97, "src"]], "Contribute": [[98, "contribute"]], "GitHub issues": [[98, "github-issues"]], "Pull requests": [[98, "pull-requests"]], "Conventions": [[98, "conventions"]], "Code quality": [[98, "code-quality"]], "Install": [[100, "install"]], "Installing with IceTray": [[100, "installing-with-icetray"]], "Installing stand-alone": [[100, "installing-stand-alone"]], "Running in Docker": [[100, "running-in-docker"]]}, "indexentries": {"graphnet": [[1, "module-graphnet"]], "module": [[1, "module-graphnet"], [2, "module-graphnet.constants"], [3, "module-graphnet.data"], [4, "module-graphnet.data.constants"], [5, "module-graphnet.data.dataconverter"], [14, "module-graphnet.data.extractors"], [15, "module-graphnet.data.extractors.i3extractor"], [16, "module-graphnet.data.extractors.i3featureextractor"], [17, "module-graphnet.data.extractors.i3genericextractor"], [18, "module-graphnet.data.extractors.i3hybridrecoextractor"], [19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"], [20, "module-graphnet.data.extractors.i3particleextractor"], [21, "module-graphnet.data.extractors.i3pisaextractor"], [22, "module-graphnet.data.extractors.i3quesoextractor"], [23, "module-graphnet.data.extractors.i3retroextractor"], [24, "module-graphnet.data.extractors.i3splinempeextractor"], [25, "module-graphnet.data.extractors.i3truthextractor"], [26, "module-graphnet.data.extractors.i3tumextractor"], [27, "module-graphnet.data.extractors.utilities"], [28, "module-graphnet.data.extractors.utilities.collections"], [29, "module-graphnet.data.extractors.utilities.frames"], [30, "module-graphnet.data.extractors.utilities.types"], [31, "module-graphnet.data.parquet"], [32, "module-graphnet.data.parquet.parquet_dataconverter"], [34, "module-graphnet.data.sqlite"], [35, "module-graphnet.data.sqlite.sqlite_dataconverter"], [36, "module-graphnet.data.sqlite.sqlite_utilities"], [37, "module-graphnet.data.utilities"], [38, "module-graphnet.data.utilities.parquet_to_sqlite"], [39, "module-graphnet.data.utilities.random"], [40, "module-graphnet.data.utilities.string_selection_resolver"], [41, "module-graphnet.deployment"], [74, "module-graphnet.pisa"], [75, "module-graphnet.pisa.fitting"], [76, "module-graphnet.pisa.plotting"], [77, "module-graphnet.training"], [82, "module-graphnet.training.weight_fitting"], [83, "module-graphnet.utilities"], [84, "module-graphnet.utilities.argparse"], [92, "module-graphnet.utilities.decorators"], [93, "module-graphnet.utilities.filesys"], [94, "module-graphnet.utilities.imports"], [95, "module-graphnet.utilities.logging"]], "graphnet.constants": [[2, "module-graphnet.constants"]], "graphnet.data": [[3, "module-graphnet.data"]], "deepcore (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.DEEPCORE"]], "deepcore (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.DEEPCORE"]], "features (class in graphnet.data.constants)": [[4, "graphnet.data.constants.FEATURES"]], "icecube86 (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.ICECUBE86"]], "icecube86 (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.ICECUBE86"]], "kaggle (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.KAGGLE"]], "kaggle (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.KAGGLE"]], "prometheus (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.PROMETHEUS"]], "prometheus (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.PROMETHEUS"]], "truth (class in graphnet.data.constants)": [[4, "graphnet.data.constants.TRUTH"]], "upgrade (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.UPGRADE"]], "upgrade (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.UPGRADE"]], "graphnet.data.constants": [[4, "module-graphnet.data.constants"]], "dataconverter (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.DataConverter"]], "fileset (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.FileSet"]], "cache_output_files() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.cache_output_files"]], "execute() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.execute"]], "file_suffix (graphnet.data.dataconverter.dataconverter property)": [[5, "graphnet.data.dataconverter.DataConverter.file_suffix"]], "gcd_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.gcd_file"]], "get_map_function() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.get_map_function"]], "graphnet.data.dataconverter": [[5, "module-graphnet.data.dataconverter"]], "i3_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.i3_file"]], "init_global_index() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.init_global_index"]], "merge_files() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.merge_files"]], "save_data() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.save_data"]], "graphnet.data.extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor"]], "i3extractorcollection (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "name (graphnet.data.extractors.i3extractor.i3extractor property)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.name"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractor method)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.set_files"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractorcollection method)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection.set_files"]], "i3featureextractor (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"]], "i3featureextractoricecube86 (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCube86"]], "i3featureextractoricecubedeepcore (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeDeepCore"]], "i3featureextractoricecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeUpgrade"]], "i3pulsenoisetruthflagicecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor (class in graphnet.data.extractors.i3genericextractor)": [[17, "graphnet.data.extractors.i3genericextractor.I3GenericExtractor"]], "graphnet.data.extractors.i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3galacticplanehybridrecoextractor (class in graphnet.data.extractors.i3hybridrecoextractor)": [[18, "graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelextractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)": [[19, "graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor (class in graphnet.data.extractors.i3particleextractor)": [[20, "graphnet.data.extractors.i3particleextractor.I3ParticleExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor (class in graphnet.data.extractors.i3pisaextractor)": [[21, "graphnet.data.extractors.i3pisaextractor.I3PISAExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor (class in graphnet.data.extractors.i3quesoextractor)": [[22, "graphnet.data.extractors.i3quesoextractor.I3QUESOExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor (class in graphnet.data.extractors.i3retroextractor)": [[23, "graphnet.data.extractors.i3retroextractor.I3RetroExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeicextractor (class in graphnet.data.extractors.i3splinempeextractor)": [[24, "graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor (class in graphnet.data.extractors.i3truthextractor)": [[25, "graphnet.data.extractors.i3truthextractor.I3TruthExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor (class in graphnet.data.extractors.i3tumextractor)": [[26, "graphnet.data.extractors.i3tumextractor.I3TUMExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "graphnet.data.extractors.utilities": [[27, "module-graphnet.data.extractors.utilities"]], "flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.flatten_nested_dictionary"]], "graphnet.data.extractors.utilities.collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "serialise() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.serialise"]], "transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.transpose_list_of_dicts"]], "frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_montecarlo"]], "frame_is_noise() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_noise"]], "get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.break_cyclic_recursion"]], "cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_object_to_pure_python"]], "cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python"]], "get_member_variables() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.get_member_variables"]], "graphnet.data.extractors.utilities.types": [[30, "module-graphnet.data.extractors.utilities.types"]], "is_boost_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_class"]], "is_boost_enum() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_enum"]], "is_icecube_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_icecube_class"]], "is_method() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_method"]], "is_type() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_type"]], "graphnet.data.parquet": [[31, "module-graphnet.data.parquet"]], "parquetdataconverter (class in graphnet.data.parquet.parquet_dataconverter)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter"]], "file_suffix (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter attribute)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.file_suffix"]], "graphnet.data.parquet.parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "merge_files() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.merge_files"]], "save_data() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.save_data"]], "graphnet.data.sqlite": [[34, "module-graphnet.data.sqlite"]], "sqlitedataconverter (class in graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter"]], "any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty"]], "construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe"]], "file_suffix (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter attribute)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.file_suffix"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "is_mc_tree() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_mc_tree"]], "is_pulse_map() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_pulse_map"]], "merge_files() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files"]], "save_data() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data"]], "attach_index() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.attach_index"]], "create_table() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table"]], "create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql"]], "database_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_exists"]], "database_table_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_table_exists"]], "graphnet.data.sqlite.sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "run_sql_code() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.run_sql_code"]], "save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.save_to_sql"]], "graphnet.data.utilities": [[37, "module-graphnet.data.utilities"]], "parquettosqliteconverter (class in graphnet.data.utilities.parquet_to_sqlite)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "run() (graphnet.data.utilities.parquet_to_sqlite.parquettosqliteconverter method)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run"]], "graphnet.data.utilities.random": [[39, "module-graphnet.data.utilities.random"]], "pairwise_shuffle() (in module graphnet.data.utilities.random)": [[39, "graphnet.data.utilities.random.pairwise_shuffle"]], "stringselectionresolver (class in graphnet.data.utilities.string_selection_resolver)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "resolve() (graphnet.data.utilities.string_selection_resolver.stringselectionresolver method)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve"]], "graphnet.deployment": [[41, "module-graphnet.deployment"]], "graphnet.pisa": [[74, "module-graphnet.pisa"]], "contourfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.ContourFitter"]], "weightfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.WeightFitter"]], "config_updater() (in module graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.config_updater"]], "fit_1d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_1d_contour"]], "fit_2d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_2d_contour"]], "fit_weights() (graphnet.pisa.fitting.weightfitter method)": [[75, "graphnet.pisa.fitting.WeightFitter.fit_weights"]], "graphnet.pisa.fitting": [[75, "module-graphnet.pisa.fitting"]], "graphnet.pisa.plotting": [[76, "module-graphnet.pisa.plotting"]], "plot_1d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_1D_contour"]], "plot_2d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_2D_contour"]], "read_entry() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.read_entry"]], "graphnet.training": [[77, "module-graphnet.training"]], "bjoernlow (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.BjoernLow"]], "uniform (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.Uniform"]], "weightfitter (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.WeightFitter"]], "fit() (graphnet.training.weight_fitting.weightfitter method)": [[82, "graphnet.training.weight_fitting.WeightFitter.fit"]], "graphnet.training.weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "graphnet.utilities": [[83, "module-graphnet.utilities"]], "argumentparser (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.ArgumentParser"]], "options (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.Options"]], "contains() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.contains"]], "graphnet.utilities.argparse": [[84, "module-graphnet.utilities.argparse"]], "pop_default() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.pop_default"]], "standard_arguments (graphnet.utilities.argparse.argumentparser attribute)": [[84, "graphnet.utilities.argparse.ArgumentParser.standard_arguments"]], "with_standard_arguments() (graphnet.utilities.argparse.argumentparser method)": [[84, "graphnet.utilities.argparse.ArgumentParser.with_standard_arguments"]], "graphnet.utilities.decorators": [[92, "module-graphnet.utilities.decorators"]], "find_i3_files() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.find_i3_files"]], "graphnet.utilities.filesys": [[93, "module-graphnet.utilities.filesys"]], "has_extension() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.has_extension"]], "is_gcd_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_gcd_file"]], "is_i3_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_i3_file"]], "graphnet.utilities.imports": [[94, "module-graphnet.utilities.imports"]], "has_icecube_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_icecube_package"]], "has_pisa_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_pisa_package"]], "has_torch_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_torch_package"]], "requires_icecube() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.requires_icecube"]], "logger (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.Logger"]], "repeatfilter (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.RepeatFilter"]], "critical() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.critical"]], "debug() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.debug"]], "error() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.error"]], "file_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.file_handlers"]], "filter() (graphnet.utilities.logging.repeatfilter method)": [[95, "graphnet.utilities.logging.RepeatFilter.filter"]], "graphnet.utilities.logging": [[95, "module-graphnet.utilities.logging"]], "handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.handlers"]], "info() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.info"]], "nb_repeats_allowed (graphnet.utilities.logging.repeatfilter attribute)": [[95, "graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed"]], "setlevel() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.setLevel"]], "stream_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.stream_handlers"]], "warning() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning"]], "warning_once() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning_once"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["about", "api/graphnet", "api/graphnet.constants", "api/graphnet.data", "api/graphnet.data.constants", "api/graphnet.data.dataconverter", "api/graphnet.data.dataloader", "api/graphnet.data.dataset", "api/graphnet.data.dataset.dataset", "api/graphnet.data.dataset.parquet", "api/graphnet.data.dataset.parquet.parquet_dataset", "api/graphnet.data.dataset.sqlite", "api/graphnet.data.dataset.sqlite.sqlite_dataset", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed", "api/graphnet.data.extractors", "api/graphnet.data.extractors.i3extractor", "api/graphnet.data.extractors.i3featureextractor", "api/graphnet.data.extractors.i3genericextractor", "api/graphnet.data.extractors.i3hybridrecoextractor", "api/graphnet.data.extractors.i3ntmuonlabelsextractor", "api/graphnet.data.extractors.i3particleextractor", "api/graphnet.data.extractors.i3pisaextractor", "api/graphnet.data.extractors.i3quesoextractor", "api/graphnet.data.extractors.i3retroextractor", "api/graphnet.data.extractors.i3splinempeextractor", "api/graphnet.data.extractors.i3truthextractor", "api/graphnet.data.extractors.i3tumextractor", "api/graphnet.data.extractors.utilities", "api/graphnet.data.extractors.utilities.collections", "api/graphnet.data.extractors.utilities.frames", "api/graphnet.data.extractors.utilities.types", "api/graphnet.data.parquet", "api/graphnet.data.parquet.parquet_dataconverter", "api/graphnet.data.pipeline", "api/graphnet.data.sqlite", "api/graphnet.data.sqlite.sqlite_dataconverter", "api/graphnet.data.sqlite.sqlite_utilities", "api/graphnet.data.utilities", "api/graphnet.data.utilities.parquet_to_sqlite", "api/graphnet.data.utilities.random", "api/graphnet.data.utilities.string_selection_resolver", "api/graphnet.deployment", "api/graphnet.deployment.i3modules", "api/graphnet.deployment.i3modules.deployer", "api/graphnet.deployment.i3modules.graphnet_module", "api/graphnet.models", "api/graphnet.models.coarsening", "api/graphnet.models.components", "api/graphnet.models.components.layers", "api/graphnet.models.components.pool", "api/graphnet.models.detector", "api/graphnet.models.detector.detector", "api/graphnet.models.detector.icecube", "api/graphnet.models.detector.prometheus", "api/graphnet.models.gnn", "api/graphnet.models.gnn.convnet", "api/graphnet.models.gnn.dynedge", "api/graphnet.models.gnn.dynedge_jinst", "api/graphnet.models.gnn.dynedge_kaggle_tito", "api/graphnet.models.gnn.gnn", "api/graphnet.models.graphs", "api/graphnet.models.graphs.edges", "api/graphnet.models.graphs.edges.edges", "api/graphnet.models.graphs.graph_definition", "api/graphnet.models.graphs.graphs", "api/graphnet.models.graphs.nodes", "api/graphnet.models.graphs.nodes.nodes", "api/graphnet.models.model", "api/graphnet.models.standard_model", "api/graphnet.models.task", "api/graphnet.models.task.classification", "api/graphnet.models.task.reconstruction", "api/graphnet.models.task.task", "api/graphnet.models.utils", "api/graphnet.pisa", "api/graphnet.pisa.fitting", "api/graphnet.pisa.plotting", "api/graphnet.training", "api/graphnet.training.callbacks", "api/graphnet.training.labels", "api/graphnet.training.loss_functions", "api/graphnet.training.utils", "api/graphnet.training.weight_fitting", "api/graphnet.utilities", "api/graphnet.utilities.argparse", "api/graphnet.utilities.config", "api/graphnet.utilities.config.base_config", "api/graphnet.utilities.config.configurable", "api/graphnet.utilities.config.dataset_config", "api/graphnet.utilities.config.model_config", "api/graphnet.utilities.config.parsing", "api/graphnet.utilities.config.training_config", "api/graphnet.utilities.decorators", "api/graphnet.utilities.filesys", "api/graphnet.utilities.imports", "api/graphnet.utilities.logging", "api/graphnet.utilities.maths", "api/modules", "contribute", "index", "install"], "filenames": ["about.md", "api/graphnet.rst", "api/graphnet.constants.rst", "api/graphnet.data.rst", "api/graphnet.data.constants.rst", "api/graphnet.data.dataconverter.rst", "api/graphnet.data.dataloader.rst", "api/graphnet.data.dataset.rst", "api/graphnet.data.dataset.dataset.rst", "api/graphnet.data.dataset.parquet.rst", "api/graphnet.data.dataset.parquet.parquet_dataset.rst", "api/graphnet.data.dataset.sqlite.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset.rst", "api/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.rst", "api/graphnet.data.extractors.rst", "api/graphnet.data.extractors.i3extractor.rst", "api/graphnet.data.extractors.i3featureextractor.rst", "api/graphnet.data.extractors.i3genericextractor.rst", "api/graphnet.data.extractors.i3hybridrecoextractor.rst", "api/graphnet.data.extractors.i3ntmuonlabelsextractor.rst", "api/graphnet.data.extractors.i3particleextractor.rst", "api/graphnet.data.extractors.i3pisaextractor.rst", "api/graphnet.data.extractors.i3quesoextractor.rst", "api/graphnet.data.extractors.i3retroextractor.rst", "api/graphnet.data.extractors.i3splinempeextractor.rst", "api/graphnet.data.extractors.i3truthextractor.rst", "api/graphnet.data.extractors.i3tumextractor.rst", "api/graphnet.data.extractors.utilities.rst", "api/graphnet.data.extractors.utilities.collections.rst", "api/graphnet.data.extractors.utilities.frames.rst", "api/graphnet.data.extractors.utilities.types.rst", "api/graphnet.data.parquet.rst", "api/graphnet.data.parquet.parquet_dataconverter.rst", "api/graphnet.data.pipeline.rst", "api/graphnet.data.sqlite.rst", "api/graphnet.data.sqlite.sqlite_dataconverter.rst", "api/graphnet.data.sqlite.sqlite_utilities.rst", "api/graphnet.data.utilities.rst", "api/graphnet.data.utilities.parquet_to_sqlite.rst", "api/graphnet.data.utilities.random.rst", "api/graphnet.data.utilities.string_selection_resolver.rst", "api/graphnet.deployment.rst", "api/graphnet.deployment.i3modules.rst", "api/graphnet.deployment.i3modules.deployer.rst", "api/graphnet.deployment.i3modules.graphnet_module.rst", "api/graphnet.models.rst", "api/graphnet.models.coarsening.rst", "api/graphnet.models.components.rst", "api/graphnet.models.components.layers.rst", "api/graphnet.models.components.pool.rst", "api/graphnet.models.detector.rst", "api/graphnet.models.detector.detector.rst", "api/graphnet.models.detector.icecube.rst", "api/graphnet.models.detector.prometheus.rst", "api/graphnet.models.gnn.rst", "api/graphnet.models.gnn.convnet.rst", "api/graphnet.models.gnn.dynedge.rst", "api/graphnet.models.gnn.dynedge_jinst.rst", "api/graphnet.models.gnn.dynedge_kaggle_tito.rst", "api/graphnet.models.gnn.gnn.rst", "api/graphnet.models.graphs.rst", "api/graphnet.models.graphs.edges.rst", "api/graphnet.models.graphs.edges.edges.rst", "api/graphnet.models.graphs.graph_definition.rst", "api/graphnet.models.graphs.graphs.rst", "api/graphnet.models.graphs.nodes.rst", "api/graphnet.models.graphs.nodes.nodes.rst", "api/graphnet.models.model.rst", "api/graphnet.models.standard_model.rst", "api/graphnet.models.task.rst", "api/graphnet.models.task.classification.rst", "api/graphnet.models.task.reconstruction.rst", "api/graphnet.models.task.task.rst", "api/graphnet.models.utils.rst", "api/graphnet.pisa.rst", "api/graphnet.pisa.fitting.rst", "api/graphnet.pisa.plotting.rst", "api/graphnet.training.rst", "api/graphnet.training.callbacks.rst", "api/graphnet.training.labels.rst", "api/graphnet.training.loss_functions.rst", "api/graphnet.training.utils.rst", "api/graphnet.training.weight_fitting.rst", "api/graphnet.utilities.rst", "api/graphnet.utilities.argparse.rst", "api/graphnet.utilities.config.rst", "api/graphnet.utilities.config.base_config.rst", "api/graphnet.utilities.config.configurable.rst", "api/graphnet.utilities.config.dataset_config.rst", "api/graphnet.utilities.config.model_config.rst", "api/graphnet.utilities.config.parsing.rst", "api/graphnet.utilities.config.training_config.rst", "api/graphnet.utilities.decorators.rst", "api/graphnet.utilities.filesys.rst", "api/graphnet.utilities.imports.rst", "api/graphnet.utilities.logging.rst", "api/graphnet.utilities.maths.rst", "api/modules.rst", "contribute.md", "index.rst", "install.md"], "titles": ["About", "API", "constants", "data", "constants", "dataconverter", "dataloader", "dataset", "dataset", "parquet", "parquet_dataset", "sqlite", "sqlite_dataset", "sqlite_dataset_perturbed", "extractors", "i3extractor", "i3featureextractor", "i3genericextractor", "i3hybridrecoextractor", "i3ntmuonlabelsextractor", "i3particleextractor", "i3pisaextractor", "i3quesoextractor", "i3retroextractor", "i3splinempeextractor", "i3truthextractor", "i3tumextractor", "utilities", "collections", "frames", "types", "parquet", "parquet_dataconverter", "pipeline", "sqlite", "sqlite_dataconverter", "sqlite_utilities", "utilities", "parquet_to_sqlite", "random", "string_selection_resolver", "deployment", "i3modules", "deployer", "graphnet_module", "models", "coarsening", "components", "layers", "pool", "detector", "detector", "icecube", "prometheus", "gnn", "convnet", "dynedge", "dynedge_jinst", "dynedge_kaggle_tito", "gnn", "graphs", "edges", "edges", "graph_definition", "graphs", "nodes", "nodes", "model", "standard_model", "task", "classification", "reconstruction", "task", "utils", "pisa", "fitting", "plotting", "training", "callbacks", "labels", "loss_functions", "utils", "weight_fitting", "utilities", "argparse", "config", "base_config", "configurable", "dataset_config", "model_config", "parsing", "training_config", "decorators", "filesys", "imports", "logging", "maths", "src", "Contribute", "About", "Install"], "terms": {"graphnet": [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 35, 36, 37, 38, 39, 40, 41, 44, 45, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 98, 99, 100], "i": [0, 1, 8, 10, 12, 13, 15, 17, 28, 29, 30, 35, 36, 39, 40, 44, 46, 49, 55, 56, 62, 66, 70, 71, 72, 73, 76, 78, 79, 80, 82, 84, 89, 90, 93, 94, 95, 98, 99, 100], "an": [0, 5, 30, 32, 33, 35, 40, 44, 63, 80, 93, 95, 98, 99, 100], "open": [0, 98, 99], "sourc": [0, 4, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 98, 99], "python": [0, 1, 5, 14, 15, 17, 28, 30, 98, 99, 100], "framework": [0, 99], "aim": [0, 1, 98, 99], "provid": [0, 1, 8, 10, 12, 13, 44, 45, 80, 98, 99, 100], "high": [0, 99], "qualiti": [0, 99], "user": [0, 45, 78, 99, 100], "friendli": [0, 99], "end": [0, 1, 5, 32, 35, 99], "function": [0, 5, 6, 8, 30, 36, 39, 44, 46, 49, 52, 53, 63, 67, 70, 71, 72, 73, 75, 76, 80, 81, 83, 88, 89, 90, 93, 94, 96, 99], "perform": [0, 46, 48, 49, 54, 56, 58, 68, 70, 71, 72, 99], "reconstruct": [0, 1, 16, 18, 19, 23, 24, 26, 33, 41, 45, 58, 69, 72, 99], "task": [0, 1, 45, 68, 70, 71, 80, 98, 99], "neutrino": [0, 1, 48, 58, 75, 99], "telescop": [0, 1, 99], "us": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 15, 20, 25, 27, 28, 32, 33, 35, 36, 37, 38, 40, 41, 44, 45, 48, 49, 51, 56, 57, 58, 62, 63, 64, 67, 69, 70, 71, 72, 73, 75, 78, 79, 80, 82, 83, 84, 85, 86, 88, 89, 90, 91, 94, 95, 98, 99, 100], "graph": [0, 1, 6, 8, 10, 12, 13, 44, 45, 48, 49, 51, 61, 62, 63, 65, 66, 73, 79, 81, 98, 99], "neural": [0, 1, 99], "network": [0, 1, 55, 99], "gnn": [0, 1, 33, 45, 55, 56, 57, 58, 63, 68, 99, 100], "make": [0, 5, 82, 88, 89, 98, 99, 100], "fast": [0, 99, 100], "easi": [0, 99], "train": [0, 1, 7, 13, 40, 41, 44, 63, 68, 78, 79, 80, 81, 82, 84, 88, 89, 91, 97, 99, 100], "complex": [0, 45, 99], "model": [0, 1, 13, 41, 44, 46, 47, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 68, 69, 70, 71, 72, 73, 76, 77, 78, 80, 81, 84, 86, 88, 89, 91, 97, 99, 100], "can": [0, 1, 8, 10, 12, 13, 15, 17, 20, 38, 44, 49, 63, 75, 76, 82, 84, 86, 88, 89, 98, 99, 100], "event": [0, 1, 8, 10, 12, 13, 22, 36, 38, 40, 44, 49, 63, 70, 71, 72, 73, 75, 80, 82, 88, 99], "state": [0, 99], "art": [0, 99], "arbitrari": [0, 99], "detector": [0, 1, 25, 45, 52, 53, 63, 64, 66, 68, 99], "configur": [0, 1, 8, 45, 67, 68, 75, 83, 85, 86, 88, 89, 91, 95, 99], "infer": [0, 1, 33, 41, 44, 68, 70, 71, 72, 99, 100], "time": [0, 4, 36, 46, 49, 71, 95, 99, 100], "ar": [0, 1, 4, 5, 8, 10, 12, 13, 17, 30, 32, 35, 38, 40, 44, 49, 56, 58, 60, 61, 62, 63, 64, 65, 70, 75, 80, 82, 88, 89, 98, 99, 100], "order": [0, 28, 46, 73, 99], "magnitud": [0, 99], "faster": [0, 99], "than": [0, 6, 70, 71, 72, 81, 95, 99], "tradit": [0, 99], "techniqu": [0, 99], "common": [0, 1, 80, 86, 91, 92, 94, 99], "ml": [0, 1, 99], "develop": [0, 1, 98, 99, 100], "physicist": [0, 1, 99], "wish": [0, 98, 99], "tool": [0, 1, 99], "research": [0, 99], "By": [0, 38, 70, 71, 72, 99], "unit": [0, 5, 94, 98, 99], "both": [0, 17, 70, 71, 72, 76, 99], "group": [0, 5, 32, 35, 49, 99], "increas": [0, 78, 99], "longev": [0, 99], "usabl": [0, 99], "individu": [0, 5, 8, 10, 12, 13, 49, 56, 73, 99], "code": [0, 25, 36, 63, 88, 89, 99], "contribut": [0, 99, 100], "from": [0, 1, 6, 8, 10, 12, 13, 14, 15, 17, 19, 20, 22, 28, 29, 30, 33, 35, 38, 44, 49, 58, 62, 63, 66, 67, 70, 71, 72, 73, 76, 78, 79, 80, 86, 87, 88, 89, 91, 95, 98, 99, 100], "build": [0, 1, 45, 51, 62, 66, 67, 86, 88, 89, 99], "gener": [0, 5, 8, 10, 12, 13, 17, 44, 60, 61, 65, 70, 80, 99], "reusabl": [0, 99], "softwar": [0, 80, 99], "packag": [0, 1, 39, 90, 93, 94, 98, 99, 100], "base": [0, 4, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 33, 35, 38, 40, 44, 46, 48, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 82, 84, 86, 87, 88, 89, 91, 94, 95, 99], "engin": [0, 99], "best": [0, 98, 99], "practic": [0, 98, 99], "lower": [0, 76, 99], "technic": [0, 99], "threshold": [0, 44, 62, 99], "most": [0, 1, 40, 99, 100], "scientif": [0, 1, 99], "problem": [0, 62, 98, 99], "The": [0, 5, 8, 10, 12, 28, 30, 33, 35, 36, 44, 46, 48, 49, 56, 58, 62, 63, 70, 71, 72, 73, 75, 76, 78, 79, 80, 99], "improv": [0, 1, 84, 99], "classif": [0, 1, 45, 69, 72, 80, 99], "yield": [0, 56, 75, 80, 99], "veri": [0, 40, 99], "accur": [0, 99], "e": [0, 1, 5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 28, 30, 32, 33, 35, 36, 40, 44, 46, 48, 49, 51, 52, 53, 55, 59, 62, 63, 66, 67, 68, 70, 71, 72, 73, 78, 79, 80, 82, 86, 95, 98, 99, 100], "g": [0, 1, 5, 8, 10, 12, 13, 25, 28, 30, 32, 33, 35, 36, 40, 44, 49, 63, 70, 71, 72, 73, 82, 95, 98, 99, 100], "low": [0, 99], "energi": [0, 4, 33, 70, 71, 72, 82, 99], "observ": [0, 99], "icecub": [0, 1, 16, 29, 30, 45, 48, 50, 58, 94, 99, 100], "here": [0, 98, 99, 100], "implement": [0, 1, 5, 15, 31, 32, 34, 35, 48, 55, 56, 57, 58, 62, 80, 98, 99], "wa": [0, 99], "appli": [0, 8, 10, 12, 13, 15, 49, 55, 56, 57, 58, 59, 68, 90, 99], "oscil": [0, 74, 99], "lead": [0, 99], "signific": [0, 99], "angular": [0, 99], "rang": [0, 70, 71, 72, 99], "relev": [0, 1, 30, 39, 93, 98, 99], "studi": [0, 99], "furthermor": [0, 99], "shown": [0, 99], "could": [0, 98, 99], "muon": [0, 19, 99], "v": [0, 99], "therebi": [0, 1, 88, 89, 99], "effici": [0, 99], "puriti": [0, 99], "sampl": [0, 40, 99], "analysi": [0, 33, 99, 100], "similarli": [0, 30, 99], "ha": [0, 5, 30, 32, 35, 36, 44, 55, 80, 93, 99, 100], "great": [0, 99], "point": [0, 24, 79, 80, 99], "analys": [0, 41, 74, 99], "final": [0, 49, 78, 88, 99], "millisecond": [0, 99], "allow": [0, 41, 45, 49, 78, 86, 91, 99, 100], "whole": [0, 99], "new": [0, 1, 35, 48, 86, 91, 98, 99], "type": [0, 5, 6, 8, 10, 12, 13, 14, 15, 27, 28, 29, 32, 35, 36, 38, 39, 40, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 72, 73, 75, 76, 78, 80, 81, 82, 84, 86, 87, 88, 89, 90, 93, 94, 95, 96, 98, 99], "cosmic": [0, 99], "alert": [0, 99], "which": [0, 8, 10, 12, 13, 15, 16, 25, 29, 33, 40, 44, 46, 49, 56, 67, 70, 75, 80, 84, 99, 100], "were": [0, 99], "previous": [0, 99], "unfeas": [0, 99], "possibl": [0, 28, 98, 99], "identifi": [0, 5, 8, 10, 12, 13, 25, 88, 89, 99], "10": [0, 33, 84, 99], "tev": [0, 99], "monitor": [0, 99], "rate": [0, 78, 99], "direct": [0, 58, 70, 71, 72, 77, 79, 99], "real": [0, 99], "thi": [0, 3, 5, 8, 10, 12, 13, 15, 17, 30, 32, 35, 36, 39, 44, 45, 49, 56, 66, 68, 70, 71, 72, 73, 75, 76, 78, 80, 82, 86, 88, 89, 91, 95, 98, 99, 100], "enabl": [0, 3, 99], "first": [0, 78, 86, 91, 98, 99], "ever": [0, 99], "despit": [0, 99], "larg": [0, 80, 99], "background": [0, 99], "origin": [0, 75, 99], "compris": [0, 99], "number": [0, 5, 8, 10, 12, 13, 32, 33, 35, 40, 48, 49, 55, 56, 57, 58, 59, 62, 64, 66, 70, 71, 72, 78, 84, 99], "modul": [0, 3, 8, 30, 33, 41, 44, 45, 48, 50, 54, 60, 61, 63, 64, 65, 67, 69, 74, 77, 83, 85, 88, 89, 90, 91, 94, 99], "necessari": [0, 28, 98, 99], "workflow": [0, 99], "ingest": [0, 1, 3, 50, 99], "raw": [0, 66, 99], "data": [0, 1, 4, 5, 6, 8, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 46, 48, 49, 50, 51, 52, 55, 56, 57, 58, 59, 62, 63, 64, 67, 68, 70, 71, 72, 73, 79, 81, 84, 86, 88, 91, 94, 97, 99, 100], "domain": [0, 1, 3, 41, 99], "specif": [0, 1, 3, 5, 8, 10, 12, 16, 30, 31, 32, 34, 35, 36, 41, 46, 49, 50, 51, 52, 53, 54, 59, 62, 63, 66, 68, 69, 70, 71, 72, 80, 98, 99, 100], "format": [0, 1, 3, 5, 8, 28, 32, 35, 76, 88, 98, 99, 100], "deploi": [0, 1, 41, 44, 99], "chain": [0, 1, 41, 45, 68, 99, 100], "illustr": [0, 98, 99], "figur": [0, 76, 99], "level": [0, 8, 10, 12, 13, 25, 36, 46, 49, 62, 67, 95, 99, 100], "overview": [0, 99], "typic": [0, 28, 99], "convert": [0, 1, 3, 5, 28, 32, 35, 38, 99, 100], "industri": [0, 3, 99], "standard": [0, 3, 4, 5, 13, 32, 35, 40, 52, 53, 63, 66, 68, 84, 98, 99], "intermedi": [0, 1, 3, 5, 8, 32, 35, 55, 99, 100], "file": [0, 1, 3, 5, 8, 10, 12, 13, 15, 28, 32, 35, 38, 39, 44, 63, 67, 75, 78, 80, 84, 85, 86, 87, 88, 89, 93, 95, 99, 100], "read": [0, 3, 8, 10, 12, 13, 28, 51, 56, 68, 69, 99, 100], "simpl": [0, 45, 99], "physic": [0, 1, 15, 29, 30, 41, 44, 45, 69, 70, 71, 72, 99], "orient": [0, 45, 99], "compon": [0, 1, 45, 48, 49, 68, 99], "manag": [0, 15, 77, 99], "experi": [0, 1, 77, 99], "log": [0, 1, 71, 77, 78, 80, 83, 99, 100], "deploy": [0, 1, 42, 44, 63, 97, 99], "modular": [0, 45, 99], "subclass": [0, 45, 99], "torch": [0, 8, 10, 12, 13, 45, 48, 63, 64, 67, 68, 94, 99, 100], "nn": [0, 45, 48, 62, 64, 99], "mean": [0, 5, 8, 10, 12, 13, 32, 35, 45, 56, 58, 80, 89, 99], "onli": [0, 1, 8, 10, 12, 13, 45, 49, 70, 71, 72, 75, 82, 89, 94, 99, 100], "need": [0, 28, 45, 67, 80, 99, 100], "import": [0, 1, 36, 45, 83, 99], "few": [0, 45, 98, 99], "exist": [0, 8, 10, 12, 13, 33, 35, 36, 45, 79, 88, 99], "purpos": [0, 45, 80, 99], "built": [0, 45, 99], "them": [0, 1, 28, 45, 56, 70, 71, 72, 75, 99, 100], "togeth": [0, 45, 62, 68, 99], "form": [0, 45, 70, 86, 91, 99], "complet": [0, 45, 68, 99], "extend": [0, 1, 99], "suit": [0, 99], "through": [0, 80, 99], "layer": [0, 45, 47, 49, 55, 56, 57, 58, 70, 71, 72, 99], "connect": [0, 62, 63, 66, 80, 99], "etc": [0, 80, 95, 99], "optimis": [0, 1, 99], "differ": [0, 8, 10, 12, 13, 15, 64, 68, 98, 99, 100], "track": [0, 15, 19, 71, 98, 99], "These": [0, 63, 98, 99], "prepar": [0, 80, 99], "satisfi": [0, 99], "o": [0, 70, 71, 72, 99], "load": [0, 6, 8, 39, 67, 86, 88, 99], "requir": [0, 21, 36, 70, 80, 88, 89, 91, 99, 100], "when": [0, 5, 8, 10, 12, 13, 28, 32, 35, 36, 44, 48, 56, 58, 79, 95, 98, 99, 100], "batch": [0, 6, 33, 46, 48, 49, 68, 73, 81, 84, 99], "do": [0, 44, 80, 88, 89, 98, 99, 100], "predict": [0, 20, 24, 26, 33, 44, 55, 67, 68, 70, 71, 72, 80, 81, 99], "either": [0, 8, 10, 12, 80, 99, 100], "contain": [0, 5, 8, 10, 12, 13, 28, 29, 32, 33, 35, 44, 56, 60, 61, 63, 64, 65, 67, 70, 71, 72, 80, 82, 84, 99, 100], "imag": [0, 1, 98, 99, 100], "portabl": [0, 99], "depend": [0, 99, 100], "free": [0, 80, 99], "split": [0, 46, 99], "up": [0, 5, 32, 35, 44, 98, 99, 100], "interfac": [0, 74, 99, 100], "block": [0, 1, 99], "pre": [0, 13, 51, 63, 79, 98, 99], "directli": [0, 15, 99], "while": [0, 17, 78, 99], "continu": [0, 80, 99], "expand": [0, 99], "": [0, 5, 6, 8, 10, 12, 13, 15, 28, 35, 38, 55, 56, 68, 70, 71, 72, 73, 78, 82, 84, 88, 89, 95, 96, 99, 100], "capabl": [0, 99], "project": [0, 98, 99], "receiv": [0, 99], "fund": [0, 99], "european": [0, 99], "union": [0, 6, 8, 10, 12, 13, 17, 28, 30, 44, 46, 48, 49, 56, 67, 68, 70, 71, 72, 88, 91, 93, 99], "horizon": [0, 99], "2020": [0, 99], "innov": [0, 99], "programm": [0, 99], "under": [0, 13, 99], "mari": [0, 99], "sk\u0142odowska": [0, 99], "curi": [0, 99], "grant": [0, 80, 99], "agreement": [0, 98, 99], "No": [0, 99], "890778": [0, 99], "work": [0, 4, 29, 98, 99, 100], "rasmu": [0, 57, 99], "\u00f8rs\u00f8e": [0, 99], "partli": [0, 99], "punch4nfdi": [0, 99], "consortium": [0, 99], "support": [0, 30, 98, 99, 100], "dfg": [0, 99], "nfdi": [0, 99], "39": [0, 99, 100], "1": [0, 5, 8, 28, 32, 35, 40, 46, 49, 56, 58, 62, 64, 70, 71, 72, 73, 78, 80, 82, 88, 99, 100], "germani": [0, 99], "conveni": [1, 98, 100], "collabor": 1, "solv": [1, 98], "It": [1, 28, 36, 44, 98], "leverag": 1, "advanc": [1, 49], "machin": [1, 100], "learn": [1, 44, 78, 100], "without": [1, 62, 66, 75, 80, 100], "have": [1, 5, 17, 32, 35, 36, 40, 49, 63, 70, 71, 72, 98, 100], "expert": 1, "themselv": [1, 88, 89], "acceler": 1, "area": 1, "phyic": 1, "design": 1, "principl": 1, "all": [1, 5, 8, 10, 12, 13, 15, 17, 32, 35, 36, 44, 48, 49, 51, 56, 59, 63, 67, 72, 80, 86, 87, 88, 89, 90, 91, 95, 98, 100], "streamlin": 1, "process": [1, 5, 13, 15, 44, 51, 56, 98, 100], "transform": [1, 49, 70, 71, 72, 82], "extens": [1, 93], "basic": 1, "across": [1, 2, 8, 10, 12, 13, 30, 37, 49, 68, 80, 83, 84, 85, 95], "variou": 1, "easili": 1, "architectur": [1, 55, 56, 57, 58, 68], "main": [1, 54, 63, 68, 98, 100], "featur": [1, 3, 4, 5, 8, 10, 12, 13, 16, 33, 44, 48, 49, 51, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 70, 73, 81, 88, 98], "i3": [1, 5, 15, 29, 30, 32, 35, 39, 44, 93, 100], "more": [1, 8, 36, 39, 86, 88, 89, 91, 95], "index": [1, 5, 8, 10, 12, 30, 36, 49, 78], "sqlite": [1, 3, 7, 12, 13, 33, 35, 36, 38, 100], "suitabl": 1, "plug": 1, "plai": 1, "abstract": [1, 5, 8, 51, 59, 63, 67, 72, 87], "awai": 1, "detail": [1, 100], "expos": 1, "physicst": 1, "what": [1, 63, 98], "i3modul": [1, 41, 44], "includ": [1, 13, 67, 68, 75, 80, 86, 98], "docker": 1, "run": [1, 38], "containeris": 1, "fashion": 1, "subpackag": [1, 3, 7, 14, 41, 45, 60, 83], "dataset": [1, 3, 6, 9, 10, 11, 12, 13, 19, 40, 63, 84, 88], "extractor": [1, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 35, 44], "parquet": [1, 3, 7, 10, 32, 38, 100], "util": [1, 3, 14, 28, 29, 30, 36, 38, 39, 40, 45, 77, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 97], "constant": [1, 3, 97], "dataconvert": [1, 3, 32, 35], "dataload": [1, 3, 33, 63, 67, 68, 81, 91], "pipelin": [1, 3], "coarsen": [1, 45, 49], "standard_model": [1, 45], "pisa": [1, 21, 33, 75, 76, 94, 97, 100], "fit": [1, 67, 74, 76, 80, 82, 91], "plot": [1, 74], "callback": [1, 67, 77], "label": [1, 8, 19, 22, 55, 63, 68, 72, 76, 77, 81], "loss_funct": [1, 70, 71, 72, 77], "weight_fit": [1, 77], "config": [1, 6, 40, 75, 80, 83, 84, 86, 87, 88, 89, 90, 91], "argpars": [1, 83], "decor": [1, 5, 83, 94], "filesi": [1, 83], "math": [1, 83], "submodul": [1, 3, 7, 9, 11, 14, 27, 31, 34, 37, 42, 45, 47, 50, 54, 60, 61, 65, 69, 74, 77, 83, 85, 90], "global": [2, 4, 56, 58, 67], "i3extractor": [3, 5, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35], "i3featureextractor": [3, 4, 14, 35, 44], "i3genericextractor": [3, 14, 35], "i3hybridrecoextractor": [3, 14], "i3ntmuonlabelsextractor": [3, 14], "i3particleextractor": [3, 14], "i3pisaextractor": [3, 14], "i3quesoextractor": [3, 14], "i3retroextractor": [3, 14], "i3splinempeextractor": [3, 14], "i3truthextractor": [3, 4, 14], "i3tumextractor": [3, 14], "parquet_dataconvert": [3, 31], "sqlite_dataconvert": [3, 34], "sqlite_util": [3, 34], "parquet_to_sqlit": [3, 37], "random": [3, 8, 10, 12, 13, 37, 40, 88], "string_selection_resolv": [3, 37], "truth": [3, 4, 8, 10, 12, 13, 16, 25, 33, 36, 63, 81, 82, 88], "fileset": [3, 5], "init_global_index": [3, 5], "cache_output_fil": [3, 5], "collate_fn": [3, 6, 77, 81], "do_shuffl": [3, 6], "insqlitepipelin": [3, 33], "class": [4, 5, 6, 7, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 33, 34, 35, 38, 40, 44, 46, 48, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 82, 84, 86, 87, 88, 89, 90, 91, 95, 98], "object": [4, 5, 8, 10, 12, 13, 15, 17, 28, 30, 44, 49, 51, 63, 70, 71, 72, 75, 84, 95], "namespac": [4, 67], "name": [4, 5, 6, 8, 10, 12, 13, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 30, 32, 33, 35, 36, 38, 44, 62, 63, 64, 66, 67, 70, 71, 72, 75, 79, 82, 84, 86, 88, 89, 90, 91, 95, 98, 100], "icecube86": [4, 50, 52], "dom_x": [4, 44, 46], "dom_i": [4, 44, 46], "dom_z": [4, 44, 46], "dom_tim": 4, "charg": [4, 44, 80], "rde": [4, 46], "pmt_area": [4, 46], "deepcor": [4, 16, 52], "upgrad": [4, 16, 52, 100], "string": [4, 5, 8, 10, 12, 13, 28, 32, 35, 40, 49, 86], "pmt_number": 4, "dom_numb": 4, "pmt_dir_x": 4, "pmt_dir_i": 4, "pmt_dir_z": 4, "dom_typ": 4, "prometheu": [4, 45, 50], "sensor_pos_x": 4, "sensor_pos_i": 4, "sensor_pos_z": 4, "t": [4, 30, 36, 76, 78, 80, 100], "kaggl": [4, 48, 52, 58], "x": [4, 5, 25, 32, 35, 48, 49, 66, 67, 72, 73, 76, 80, 82], "y": [4, 25, 73, 76, 100], "z": [4, 5, 25, 32, 35, 73, 100], "auxiliari": 4, "energy_track": 4, "position_x": 4, "position_i": 4, "position_z": 4, "azimuth": [4, 71, 79], "zenith": [4, 71, 79], "pid": [4, 40, 88], "elast": 4, "sim_typ": 4, "interaction_typ": 4, "interaction_tim": [4, 71], "inelast": [4, 71], "stopped_muon": 4, "injection_energi": 4, "injection_typ": 4, "injection_interaction_typ": 4, "injection_zenith": 4, "injection_azimuth": 4, "injection_bjorkenx": 4, "injection_bjorkeni": 4, "injection_position_x": 4, "injection_position_i": 4, "injection_position_z": 4, "injection_column_depth": 4, "primary_lepton_1_typ": 4, "primary_hadron_1_typ": 4, "primary_lepton_1_position_x": 4, "primary_lepton_1_position_i": 4, "primary_lepton_1_position_z": 4, "primary_hadron_1_position_x": 4, "primary_hadron_1_position_i": 4, "primary_hadron_1_position_z": 4, "primary_lepton_1_direction_theta": 4, "primary_lepton_1_direction_phi": 4, "primary_hadron_1_direction_theta": 4, "primary_hadron_1_direction_phi": 4, "primary_lepton_1_energi": 4, "primary_hadron_1_energi": 4, "total_energi": 4, "i3_fil": [5, 15], "str": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 52, 53, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 79, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 95], "gcd_file": [5, 15, 44], "paramet": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 55, 56, 57, 58, 59, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96], "output_fil": [5, 32, 35], "global_index": 5, "avail": [5, 17, 33, 94], "pool": [5, 45, 46, 47, 56, 58], "worker": [5, 32, 33, 35, 39, 84, 95], "return": [5, 6, 8, 10, 12, 13, 15, 28, 29, 30, 32, 35, 36, 38, 39, 40, 46, 48, 49, 51, 52, 53, 55, 56, 57, 58, 59, 62, 63, 66, 67, 68, 70, 72, 73, 75, 76, 78, 79, 80, 81, 82, 84, 86, 87, 88, 89, 90, 93, 94, 95, 96], "none": [5, 6, 8, 10, 12, 13, 15, 17, 25, 29, 30, 32, 33, 35, 36, 38, 40, 44, 46, 48, 49, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 75, 78, 80, 81, 82, 84, 86, 87, 88, 90, 93, 95], "synchron": 5, "list": [5, 6, 8, 10, 12, 13, 15, 17, 25, 28, 30, 32, 33, 35, 36, 38, 39, 40, 44, 46, 48, 49, 51, 56, 58, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 76, 78, 80, 81, 82, 88, 90, 91, 93, 95], "process_method": 5, "cach": 5, "output": [5, 32, 35, 38, 55, 56, 57, 59, 66, 67, 68, 75, 82, 88, 89, 100], "typevar": 5, "f": [5, 49], "bound": [5, 76], "callabl": [5, 6, 8, 30, 48, 49, 51, 52, 53, 63, 70, 71, 72, 81, 82, 86, 88, 89, 90, 94], "ani": [5, 6, 8, 10, 12, 28, 29, 30, 32, 35, 44, 48, 49, 56, 62, 63, 67, 68, 70, 72, 76, 80, 82, 84, 86, 87, 88, 89, 90, 91, 95, 100], "outdir": [5, 32, 33, 35, 38, 75], "gcd_rescu": [5, 32, 35, 93], "nb_files_to_batch": [5, 32, 35], "sequential_batch_pattern": [5, 32, 35], "input_file_batch_pattern": [5, 32, 35], "index_column": [5, 8, 10, 12, 13, 32, 35, 36, 40, 75, 81, 82, 88], "icetray_verbos": [5, 32, 35], "abc": [5, 8, 15, 33, 67, 79, 82, 87], "logger": [5, 8, 15, 33, 38, 40, 62, 67, 79, 82, 83, 95, 100], "construct": [5, 6, 8, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 32, 35, 38, 40, 46, 47, 48, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 70, 71, 72, 75, 78, 79, 80, 81, 82, 84, 87, 88, 89, 95], "regular": [5, 30, 32, 35], "express": [5, 32, 35, 67, 80], "accord": [5, 13, 32, 35, 46, 49, 62], "match": [5, 32, 35, 82, 93, 96], "certain": [5, 32, 35, 38, 75], "pattern": [5, 32, 35], "wildcard": [5, 32, 35], "same": [5, 30, 32, 35, 36, 46, 49, 70, 73, 78, 90, 95], "input": [5, 8, 10, 12, 13, 32, 33, 35, 44, 52, 55, 56, 57, 58, 59, 63, 66, 70, 72, 73, 86, 91], "replac": [5, 32, 35, 86, 88, 89, 91], "period": [5, 32, 35], "special": [5, 17, 32, 35, 44, 73], "interpret": [5, 32, 35, 70], "liter": [5, 32, 35], "charact": [5, 32, 35], "regex": [5, 32, 35], "For": [5, 30, 32, 35, 78], "instanc": [5, 8, 15, 25, 30, 32, 35, 44, 63, 67, 75, 79, 81, 87, 100], "A": [5, 8, 32, 33, 35, 44, 49, 64, 73, 75, 80, 82, 100], "_": [5, 32, 35], "0": [5, 8, 10, 12, 32, 35, 40, 44, 46, 49, 55, 56, 58, 62, 64, 73, 75, 76, 80, 88], "9": [5, 32, 35], "5": [5, 8, 10, 12, 32, 35, 40, 84, 100], "zst": [5, 32, 35], "find": [5, 32, 35, 93], "whose": [5, 32, 35, 44], "one": [5, 8, 32, 35, 36, 44, 49, 67, 88, 89, 93, 98, 100], "capit": [5, 32, 35], "letter": [5, 32, 35], "follow": [5, 32, 35, 56, 68, 80, 82, 98, 100], "underscor": [5, 32, 35], "five": [5, 32, 35], "upgrade_genie_step4_141020_a_000000": [5, 32, 35], "upgrade_genie_step4_141020_a_000001": [5, 32, 35], "upgrade_genie_step4_141020_a_000008": [5, 32, 35], "upgrade_genie_step4_141020_a_000009": [5, 32, 35], "would": [5, 32, 35, 98], "upgrade_genie_step4_141020_a_00000x": [5, 32, 35], "suffix": [5, 32, 35], "upgrade_genie_step4_141020_a_000010": [5, 32, 35], "separ": [5, 28, 32, 35, 78, 100], "upgrade_genie_step4_141020_a_00001x": [5, 32, 35], "int": [5, 6, 8, 10, 12, 13, 19, 22, 32, 33, 35, 40, 48, 49, 55, 56, 57, 58, 59, 62, 64, 66, 67, 68, 70, 71, 72, 73, 75, 78, 80, 81, 82, 84, 88, 91, 95], "properti": [5, 8, 15, 20, 30, 49, 59, 66, 68, 72, 79, 87, 95], "file_suffix": [5, 32, 35], "execut": [5, 36], "method": [5, 8, 10, 12, 15, 27, 28, 29, 30, 32, 35, 44, 48, 49, 71, 80, 82], "set": [5, 17, 70, 71, 72, 98], "inherit": [5, 15, 30, 51, 66, 80, 95], "path": [5, 8, 10, 12, 13, 36, 39, 44, 63, 67, 75, 76, 84, 86, 87, 88, 93, 100], "correspond": [5, 8, 10, 12, 13, 28, 30, 35, 39, 56, 63, 82, 93, 100], "gcd": [5, 15, 29, 39, 44, 93], "save_data": [5, 32, 35], "save": [5, 15, 28, 32, 35, 36, 67, 75, 80, 81, 82, 86, 87, 88, 89, 100], "ordereddict": [5, 32, 35], "extract": [5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 35, 38, 39, 44, 70, 71, 72], "merge_fil": [5, 32, 35], "input_fil": [5, 32, 35], "merg": [5, 32, 35, 80, 100], "result": [5, 32, 35, 49, 78, 80, 81, 90, 100], "option": [5, 8, 10, 12, 13, 25, 32, 33, 35, 44, 48, 49, 56, 58, 63, 64, 67, 70, 71, 72, 75, 76, 80, 82, 83, 84, 86, 88, 93, 100], "default": [5, 8, 10, 12, 13, 17, 25, 28, 32, 33, 35, 36, 38, 44, 48, 49, 55, 56, 57, 58, 62, 63, 64, 66, 67, 70, 71, 72, 75, 76, 78, 79, 80, 82, 84, 86, 88, 93], "current": [5, 32, 35, 40, 78, 98, 100], "rais": [5, 8, 17, 32, 67, 86, 91], "notimplementederror": [5, 32], "If": [5, 8, 17, 32, 33, 35, 67, 70, 71, 72, 75, 78, 82, 98, 100], "been": [5, 32, 44, 80, 98], "backend": [5, 9, 11, 32, 35], "question": 5, "get_map_funct": 5, "nb_file": 5, "map": [5, 8, 10, 12, 13, 16, 17, 35, 36, 44, 52, 53, 86, 88, 89, 91], "pure": [5, 14, 15, 17, 30], "multiprocess": [5, 100], "tupl": [5, 8, 10, 12, 29, 30, 48, 56, 58, 70, 71, 72, 73, 75, 76, 81, 84], "remov": [6, 81, 84], "less": [6, 81], "two": [6, 56, 75, 78, 80, 81], "dom": [6, 8, 10, 12, 13, 46, 49, 81], "hit": [6, 81], "should": [6, 8, 10, 12, 13, 15, 28, 40, 48, 49, 80, 81, 86, 88, 89, 91, 98, 100], "occur": [6, 81], "product": [6, 81], "selection_nam": 6, "check": [6, 29, 30, 35, 36, 84, 93, 94, 98, 100], "whether": [6, 29, 30, 35, 36, 56, 67, 80, 90, 93, 94], "shuffl": [6, 39, 81], "select": [6, 8, 10, 12, 13, 22, 40, 81, 82, 88, 98], "bool": [6, 29, 30, 35, 36, 40, 44, 46, 56, 67, 68, 75, 78, 80, 81, 82, 84, 90, 93, 94, 95], "batch_siz": [6, 33, 73, 81], "num_work": [6, 81], "persistent_work": [6, 81], "prefetch_factor": 6, "kwarg": [6, 48, 62, 67, 70, 72, 80, 82, 86, 95], "t_co": 6, "classmethod": [6, 8, 67, 80, 86, 87], "from_dataset_config": 6, "datasetconfig": [6, 8, 40, 85, 88], "dict": [6, 8, 13, 17, 28, 30, 33, 35, 51, 52, 53, 63, 67, 68, 75, 76, 78, 80, 81, 84, 86, 88, 89, 90, 91], "parquet_dataset": [7, 9], "sqlite_dataset": [7, 11], "sqlite_dataset_perturb": [7, 11], "columnmissingexcept": [7, 8], "load_modul": [7, 8, 67], "parse_graph_definit": [7, 8], "ensembledataset": [7, 8, 88], "except": 8, "indic": [8, 40, 49, 78, 84, 98], "miss": 8, "column": [8, 10, 12, 13, 36, 44, 62, 63, 64, 66, 67, 68, 70, 71, 72, 73, 75, 82], "class_nam": [8, 62, 67, 89, 95], "cfg": 8, "graphdefinit": [8, 10, 12, 44, 60, 61, 63, 64, 65, 68, 81, 98], "graph_definit": [8, 10, 12, 44, 45, 60, 68, 81, 88], "pulsemap": [8, 10, 12, 13, 16, 35, 44, 81, 88], "node_truth": [8, 10, 12, 13, 81, 88], "truth_tabl": [8, 10, 12, 13, 75, 81, 82, 88], "node_truth_t": [8, 10, 12, 13, 81, 88], "string_select": [8, 10, 12, 13, 81, 88], "dtype": [8, 10, 12, 13, 63, 64, 96], "loss_weight_t": [8, 10, 12, 13, 81, 88], "loss_weight_column": [8, 10, 12, 13, 63, 81, 88], "loss_weight_default_valu": [8, 10, 12, 13, 63, 88], "seed": [8, 10, 12, 13, 40, 81, 88], "puls": [8, 10, 12, 13, 16, 17, 29, 30, 35, 36, 44, 46, 49, 66, 73], "seri": [8, 10, 12, 13, 16, 17, 29, 30, 36, 44], "node": [8, 10, 12, 13, 45, 46, 49, 55, 56, 58, 60, 61, 62, 63, 64, 70, 71, 72, 73], "multipl": [8, 10, 12, 13, 15, 78, 88, 95], "store": [8, 10, 12, 13, 15, 33, 36, 75, 79], "ad": [8, 10, 12, 13, 16, 56, 63, 75], "attribut": [8, 10, 12, 13, 46, 70, 71, 72], "event_no": [8, 10, 12, 13, 36, 40, 82, 88], "uniqu": [8, 10, 12, 13, 36, 38, 88], "indici": [8, 10, 12, 13, 29, 40, 80], "tabl": [8, 10, 12, 13, 15, 33, 35, 36, 63, 75, 82], "inform": [8, 10, 12, 13, 15, 17, 25, 76], "subset": [8, 10, 12, 13, 48, 56, 58], "given": [8, 10, 12, 13, 35, 49, 62, 70, 71, 72, 82, 84], "queri": [8, 10, 12, 36, 40], "pass": [8, 10, 12, 48, 55, 56, 57, 58, 59, 63, 67, 68, 70, 71, 72, 80, 82, 98], "float32": [8, 10, 12, 13, 63, 64], "tensor": [8, 10, 12, 13, 46, 48, 49, 51, 55, 56, 57, 58, 59, 66, 67, 68, 70, 71, 72, 73, 80, 96], "per": [8, 10, 12, 13, 17, 36, 49, 70, 71, 72, 80, 82], "loss": [8, 10, 12, 13, 63, 68, 70, 71, 72, 78, 80, 84], "weight": [8, 10, 12, 13, 44, 63, 70, 71, 72, 75, 80, 82, 89, 100], "also": [8, 10, 12, 13, 40, 88], "assign": [8, 10, 12, 13, 38, 46, 49, 98], "float": [8, 10, 12, 13, 44, 46, 55, 62, 63, 67, 75, 76, 78, 80, 81, 88], "note": [8, 10, 12, 13, 76, 89], "valu": [8, 10, 12, 13, 25, 28, 35, 36, 49, 63, 76, 79, 80, 84, 86], "specifi": [8, 10, 12, 13, 40, 46, 70, 71, 72, 76, 78, 100], "case": [8, 10, 12, 13, 17, 44, 49, 70, 71, 72, 100], "That": [8, 10, 12, 13, 56, 71, 79], "ignor": [8, 10, 12, 13, 30], "resolv": [8, 10, 12, 40], "10000": [8, 10, 12, 40], "20": [8, 10, 12, 40, 95], "defin": [8, 10, 12, 40, 44, 49, 60, 61, 62, 63, 65, 86, 88, 89, 91], "represent": [8, 10, 12, 30, 49, 64], "from_config": [8, 67, 87, 88, 89], "concaten": [8, 28, 56], "query_t": [8, 10, 12], "sequential_index": [8, 10, 12], "some": [8, 10, 12, 63], "out": [8, 56, 68, 69, 80, 95, 98, 100], "sequenti": 8, "len": 8, "self": [8, 63, 75, 86, 91], "_may_": 8, "_indic": 8, "entir": [8, 67], "impos": 8, "befor": [8, 56, 70, 71, 72, 78], "scalar": [8, 73, 80], "length": [8, 30, 78], "element": [8, 28, 30, 68, 73, 90], "present": [8, 84, 93, 94], "add_label": 8, "fn": [8, 30, 86, 90], "kei": [8, 17, 28, 29, 30, 35, 36, 46, 49, 79, 88, 89], "add": [8, 56, 84, 98, 100], "custom": [8, 63, 78], "concatdataset": 8, "singl": [8, 15, 49, 56, 79, 88, 89], "collect": [8, 14, 15, 27, 80, 96], "iter": 8, "parquetdataset": [9, 10], "pytorch": [10, 12, 13, 78, 100], "sqlitedataset": [11, 12, 13], "sqlitedatasetperturb": [11, 13], "databas": [12, 13, 33, 35, 36, 38, 75, 82, 100], "perturb": 13, "perturbation_dict": 13, "step": [13, 68, 78], "where": [13, 63, 64, 66, 79], "randomli": [13, 40, 89], "nois": [13, 16, 29, 44], "intend": [13, 100], "test": [13, 40, 70, 71, 72, 81, 88, 94, 98], "stabil": 13, "small": [13, 80], "chang": [13, 75, 80, 98], "dictionari": [13, 28, 29, 30, 33, 35, 63, 75, 76, 86, 88, 89, 91], "deviat": 13, "i3fram": [14, 15, 17, 29, 30, 44], "frame": [14, 15, 17, 27, 30, 35, 44], "i3extractorcollect": [14, 15], "i3featureextractoricecube86": [14, 16], "i3featureextractoricecubedeepcor": [14, 16], "i3featureextractoricecubeupgrad": [14, 16], "i3pulsenoisetruthflagicecubeupgrad": [14, 16], "i3galacticplanehybridrecoextractor": [14, 18], "i3ntmuonlabelextractor": [14, 19], "i3splinempeicextractor": [14, 24], "__call__": 15, "icetrai": [15, 29, 30, 44, 94], "keep": 15, "proven": 15, "set_fil": 15, "refer": [15, 88], "being": [15, 44, 70, 71, 72], "get": [15, 29, 78, 81, 100], "treat": 15, "86": [16, 52], "flag": [16, 44], "exclude_kei": 17, "dynam": [17, 48, 56, 57, 58], "pars": [17, 76, 83, 84, 85, 86, 91], "call": [17, 30, 35, 49, 75, 82, 95], "tri": [17, 30], "automat": [17, 80, 98], "cast": [17, 30], "done": [17, 49, 95, 98], "recurs": [17, 30, 90, 93], "each": [17, 28, 30, 36, 38, 39, 46, 49, 52, 53, 56, 58, 62, 63, 64, 66, 67, 70, 71, 72, 73, 75, 76, 78, 93], "look": [17, 100], "member": [17, 30, 88, 89, 95], "variabl": [17, 30, 56, 73, 82, 95], "signatur": [17, 30], "similar": [17, 30, 100], "handl": [17, 80, 84, 95], "hand": 17, "mc": [17, 35, 36], "tree": [17, 35], "trigger": 17, "exclud": [17, 38, 100], "valueerror": [17, 67], "hybrid": 18, "galatict": 18, "plane": [18, 80], "tum": [19, 26], "dnn": [19, 26], "padding_valu": [19, 22], "northeren": 19, "i3particl": 20, "other": [20, 36, 62, 80, 98], "algorithm": 20, "comparison": [20, 80], "quantiti": [21, 70, 71, 72, 73], "queso": 22, "retro": [23, 33], "splinemp": 24, "border": 25, "mctree": [25, 29], "ndarrai": [25, 63, 82], "arrai": [25, 28], "boundari": 25, "volum": 25, "coordin": [25, 73], "particl": [25, 36, 79], "start": [25, 98, 100], "stop": [25, 84], "within": [25, 46, 48, 49, 56, 62], "hard": 25, "i3mctre": 25, "flatten_nested_dictionari": [27, 28], "serialis": [27, 28], "transpose_list_of_dict": [27, 28], "frame_is_montecarlo": [27, 29], "frame_is_nois": [27, 29], "get_om_keys_and_pulseseri": [27, 29], "is_boost_enum": [27, 30], "is_boost_class": [27, 30], "is_icecube_class": [27, 30], "is_typ": [27, 30], "is_method": [27, 30], "break_cyclic_recurs": [27, 30], "get_member_vari": [27, 30], "cast_object_to_pure_python": [27, 30], "cast_pulse_series_to_pure_python": [27, 30], "manipul": [28, 60, 61, 65], "obj": [28, 30, 90], "parent_kei": 28, "flatten": 28, "nest": 28, "non": [28, 30, 35, 36, 80], "exampl": [28, 40, 46, 49, 80, 88, 89, 100], "d": [28, 63, 66, 98], "b": [28, 46, 49], "c": [28, 49, 80, 100], "2": [28, 49, 56, 58, 62, 64, 71, 73, 75, 76, 80, 88, 100], "a__b": 28, "applic": 28, "combin": [28, 88], "parent": 28, "__": [28, 30], "nester": 28, "json": [28, 88], "therefor": 28, "we": [28, 30, 40, 98, 100], "outer": 28, "abl": [28, 100], "de": 28, "transpos": 28, "mont": 29, "carlo": 29, "simul": [29, 44], "pulseseri": 29, "calibr": [29, 30], "gcd_dict": [29, 30], "p": [29, 35, 80], "om": [29, 30], "dataclass": 29, "i3calibr": 29, "indicesfor": 29, "boost": 30, "enum": 30, "ensur": [30, 39, 80, 95, 98, 100], "isn": 30, "return_discard": 30, "valid": [30, 40, 68, 70, 71, 72, 80, 84, 86, 91], "mangl": 30, "take": [30, 35, 49, 98], "mainli": 30, "cannot": [30, 86, 91], "trivial": [30, 72], "doe": [30, 89], "try": 30, "equival": 30, "its": 30, "like": [30, 49, 73, 80, 96, 98], "otherwis": [30, 80], "itself": [30, 70, 71, 72], "deem": 30, "wai": [30, 40, 98, 100], "optic": 30, "found": [30, 80], "parquetdataconvert": [31, 32], "module_dict": 33, "devic": 33, "retro_table_nam": 33, "n_worker": [33, 75], "pipeline_nam": 33, "creat": [33, 35, 36, 63, 86, 87, 91, 98, 100], "initialis": [33, 89], "gnn_module_for_energy_regress": 33, "modulelist": 33, "comput": [33, 68, 70, 71, 72, 73, 80], "directori": [33, 38, 75, 93], "100": [33, 100], "size": [33, 48, 49, 56, 57, 58, 84], "alreadi": [33, 36, 100], "error": [33, 80, 95, 98], "prompt": 33, "avoid": [33, 95, 98], "overwrit": [33, 78], "sqlitedataconvert": [34, 35, 100], "construct_datafram": [34, 35], "is_pulse_map": [34, 35], "is_mc_tre": [34, 35], "database_exist": [34, 36], "database_table_exist": [34, 36], "run_sql_cod": [34, 36], "save_to_sql": [34, 36], "attach_index": [34, 36], "create_t": [34, 36], "create_table_and_save_to_sql": [34, 36], "db": [35, 81], "max_table_s": 35, "maximum": [35, 49, 70, 71, 72, 84], "row": [35, 36], "exce": 35, "limit": [35, 80], "any_pulsemap_is_non_empti": 35, "data_dict": 35, "empti": [35, 44], "retriev": 35, "splitinicepuls": 35, "least": [35, 98, 100], "true": [35, 36, 44, 75, 78, 80, 82, 88, 89, 91], "becaus": [35, 39], "instead": [35, 80, 86, 91], "alwai": 35, "panda": [35, 40, 82], "datafram": [35, 36, 40, 67, 68, 75, 81, 82], "table_nam": [35, 36], "database_path": [36, 75, 82], "df": 36, "must": [36, 46, 78, 82, 98], "attach": 36, "default_typ": 36, "null": 36, "integer_primary_kei": 36, "NOT": [36, 80], "integ": [36, 56, 57, 80], "primari": 36, "Such": 36, "appropri": [36, 70, 71, 72], "expect": [36, 40, 44, 66], "doesn": 36, "parquettosqliteconvert": [37, 38], "pairwise_shuffl": [37, 39], "stringselectionresolv": [37, 40], "parquet_path": 38, "mc_truth_tabl": 38, "excluded_field": 38, "id": 38, "everi": [38, 100], "field": [38, 76, 79, 86, 88, 89, 91], "One": [38, 76], "choos": 38, "argument": [38, 82, 84, 86, 88, 89, 91], "exclude_field": 38, "database_nam": 38, "convers": [38, 100], "rng": 39, "relat": [39, 93], "i3_list": [39, 93], "gcd_list": [39, 93], "correpond": 39, "handi": 39, "even": 39, "files_list": 39, "gcd_shuffl": 39, "i3_shuffl": 39, "use_cach": 40, "flexibl": 40, "below": [40, 76, 82, 98, 100], "show": [40, 78], "involv": 40, "cover": 40, "yml": [40, 84, 88, 89], "50000": [40, 88], "ab": [40, 80, 88], "12": [40, 88], "14": [40, 88], "16": [40, 88], "13": [40, 100], "compat": 40, "syntax": [40, 80], "mai": [40, 66, 100], "fix": 40, "graphnet_modul": [41, 42], "graphneti3modul": [42, 44], "i3inferencemodul": [42, 44], "i3pulsecleanermodul": [42, 44], "pulsemap_extractor": 44, "produc": [44, 79, 82], "write": [44, 100], "constructor": 44, "knngraph": [44, 60, 64], "associ": [44, 63, 71, 80], "model_config": [44, 83, 85, 86, 88, 91], "state_dict": [44, 67], "model_nam": [44, 75], "prediction_column": [44, 67, 68, 81], "pulsmap": 44, "modelconfig": [44, 67, 85, 88, 89], "summar": 44, "Will": [44, 62], "help": [44, 84, 98], "entri": [44, 56, 76, 84], "dynedg": [44, 45, 54, 57, 58], "energy_reco": 44, "discard_empty_ev": 44, "clean": [44, 98, 100], "assum": [44, 51, 72, 73], "7": [44, 49, 75], "consid": [44, 100], "posit": [44, 49, 71], "signal": 44, "els": 44, "fals": [44, 56, 67, 75, 78, 80, 82, 88], "elimin": 44, "speed": 44, "especi": 44, "sinc": [44, 80], "further": 44, "calcul": [44, 62, 64, 68, 73, 79, 80], "convnet": [45, 54], "dynedge_jinst": [45, 54], "dynedge_kaggle_tito": [45, 54], "edg": [45, 48, 49, 56, 57, 58, 60, 63, 64, 65, 66, 73], "unbatch_edge_index": [45, 46], "attributecoarsen": [45, 46], "domcoarsen": [45, 46], "customdomcoarsen": [45, 46], "domandtimewindowcoarsen": [45, 46], "standardmodel": [45, 68], "calculate_xyzt_homophili": [45, 73], "calculate_distance_matrix": [45, 73], "knn_graph_batch": [45, 73], "oper": [46, 48, 54, 56], "cluster": [46, 48, 49, 56, 58], "local": [46, 84], "edge_index": [46, 48, 73], "vector": [46, 49, 80], "longtensor": [46, 49, 73], "mathbf": [46, 49], "ldot": [46, 49], "n": [46, 49, 80], "reduc": 46, "transfer_attribut": 46, "reduce_opt": 46, "avg": 46, "avg_pool": 46, "avg_pool_x": 46, "max": [46, 48, 56, 58, 80, 84], "max_pool": [46, 49], "max_pool_x": [46, 49], "min": [46, 49, 56, 58], "min_pool": [46, 47, 49], "min_pool_x": [46, 47, 49], "sum": [46, 49, 56, 58, 68], "sum_pool": [46, 47, 49], "sum_pool_x": [46, 47, 49], "forward": [46, 48, 51, 55, 56, 57, 58, 59, 62, 63, 66, 67, 68, 72, 80], "simplecoarsen": 46, "addit": [46, 48, 67, 68, 80, 82], "time_window": 46, "time_kei": 46, "window": 46, "dynedgeconv": [47, 48, 56], "edgeconvtito": [47, 48], "dyntran": [47, 48, 58], "sum_pool_and_distribut": [47, 49], "group_bi": [47, 49], "group_pulses_to_dom": [47, 49], "group_pulses_to_pmt": [47, 49], "std_pool_x": [47, 49], "std_pool": [47, 49], "aggr": 48, "nb_neighbor": 48, "features_subset": [48, 56, 58], "edgeconv": 48, "lightningmodul": [48, 67, 78, 95], "convolut": [48, 55, 56, 57, 58], "mlp": [48, 56], "aggreg": [48, 49], "8": [48, 49, 56, 64, 80, 98, 100], "neighbour": [48, 56, 58, 62, 64, 73], "after": [48, 56, 78, 84], "sequenc": 48, "slice": [48, 56, 58], "sparsetensor": 48, "messagepass": 48, "tito": [48, 58], "solut": [48, 58, 98], "deep": [48, 58], "competit": [48, 52, 58], "reset_paramet": 48, "reset": 48, "learnabl": [48, 54, 55, 56, 57, 58, 59], "messag": [48, 78, 95], "x_i": 48, "x_j": 48, "layer_s": 48, "n_head": 48, "dyntrans1": 48, "head": 48, "multiheadattent": 48, "just": [49, 100], "negat": 49, "cluster_index": 49, "distribut": [49, 56, 71, 80, 82], "ident": [49, 72], "pmt": 49, "f1": 49, "f2": 49, "6": [49, 76], "groupbi": 49, "3": [49, 55, 58, 71, 73, 75, 76, 80, 98, 100], "matrix": [49, 62, 73, 80], "mathbb": 49, "r": [49, 62, 100], "n_1": 49, "n_b": 49, "obtain": [49, 80], "wise": 49, "dens": 49, "fc": 49, "known": 49, "std": 49, "repres": [49, 63, 64, 66, 86, 88, 89], "averag": [49, 80], "torch_geometr": 49, "version": [49, 70, 71, 72, 78, 98, 100], "standardis": 50, "icecubekaggl": [50, 52], "icecubedeepcor": [50, 52], "icecubeupgrad": [50, 52], "ins": 51, "feature_map": [51, 52, 53], "node_featur": [51, 63], "node_feature_nam": [51, 63, 64, 66], "adjac": 51, "dimens": [52, 53, 55, 56, 58, 80], "prototyp": 53, "dynedgejinst": [54, 57], "dynedgetito": [54, 58], "author": [55, 57, 80], "martin": 55, "minh": 55, "nb_input": [55, 56, 57, 58, 59, 70, 71, 72], "nb_output": [55, 57, 59, 66, 70, 72], "nb_intermedi": 55, "dropout_ratio": 55, "128": [55, 56, 84], "fraction": 55, "drop": 55, "nb_neighbour": 56, "dynedge_layer_s": 56, "post_processing_layer_s": 56, "readout_layer_s": 56, "global_pooling_schem": [56, 58], "add_global_variables_after_pool": 56, "k": [56, 58, 62, 64, 73, 80], "nearest": [56, 58, 62, 64, 73], "latent": [56, 58, 70], "metric": [56, 58, 78], "dimenion": [56, 58], "multi": 56, "perceptron": 56, "256": 56, "336": 56, "hidden": [56, 57, 70, 72], "skip": 56, "post": 56, "_and_": 56, "As": 56, "last": [56, 70, 72, 78], "scheme": [56, 58], "altern": [56, 80, 98], "exact": [57, 80], "2209": 57, "03042": 57, "oerso": 57, "layer_size_scal": 57, "4": [57, 58, 71, 76], "scale": [57, 63, 70, 71, 72, 80], "ic": 58, "univers": 58, "south": 58, "pole": 58, "dyntrans_layer_s": 58, "core": 59, "edgedefinit": [60, 61, 62, 63, 65], "how": [60, 61, 65], "drawn": [60, 61, 64, 65], "between": [60, 61, 62, 65, 68, 73, 78, 80], "knnedg": [61, 62], "radialedg": [61, 62], "euclideanedg": [61, 62], "log_fold": [62, 67, 95], "_construct_edg": 62, "nb_nearest_neighbour": [62, 64], "definit": [62, 63, 64, 66, 67, 98], "space": [62, 82], "distanc": [62, 64, 73], "radiu": 62, "sphere": 62, "chosen": [62, 95], "centr": 62, "radial": 62, "center": 62, "sigma": 62, "euclidean": [62, 98], "see": [62, 63, 78, 98, 100], "http": [62, 63, 80, 98], "arxiv": [62, 80], "org": [62, 80, 100], "pdf": 62, "1809": 62, "06166": 62, "hold": 63, "alter": 63, "dure": [63, 70, 71, 72, 78], "node_definit": [63, 64], "edge_definit": 63, "geometri": 63, "nodedefinit": [63, 64, 65, 66], "truth_dict": 63, "custom_label_funct": 63, "loss_weight": [63, 70, 71, 72], "data_path": 63, "shape": [63, 66, 73, 80], "num_nod": 63, "github": [63, 80, 100], "com": [63, 80, 100], "team": [63, 98], "blob": [63, 80], "getting_start": 63, "md": 63, "your": [64, 98, 100], "nodesaspuls": [65, 66], "num_puls": 66, "overridden": 66, "set_number_of_input": 66, "measur": [66, 73], "cherenkov": 66, "radiat": 66, "train_dataload": 67, "val_dataload": 67, "max_epoch": 67, "gpu": [67, 68, 84, 100], "ckpt_path": 67, "log_every_n_step": 67, "gradient_clip_v": 67, "distribution_strategi": [67, 68], "trainer_kwarg": 67, "pytorch_lightn": [67, 95], "trainer": [67, 78, 81], "predict_as_datafram": [67, 68], "additional_attribut": [67, 68, 81], "save_state_dict": 67, "load_state_dict": 67, "karg": 67, "trust": 67, "enough": 67, "eval": [67, 100], "lambda": 67, "consequ": 67, "optimizer_class": 68, "optim": [68, 78], "adam": 68, "optimizer_kwarg": 68, "scheduler_class": 68, "scheduler_kwarg": 68, "scheduler_config": 68, "target_label": [68, 70, 71, 72], "target": [68, 70, 71, 72, 80, 91], "prediction_label": [68, 70, 71, 72], "configure_optim": 68, "shared_step": 68, "batch_idx": 68, "share": 68, "training_step": 68, "train_batch": 68, "validation_step": 68, "val_batch": 68, "compute_loss": [68, 70, 71, 72], "pred": [68, 72], "verbos": [68, 78], "activ": [68, 72, 98, 100], "mode": [68, 72], "deactiv": [68, 72], "multiclassclassificationtask": [69, 70], "binaryclassificationtask": [69, 70], "binaryclassificationtasklogit": [69, 70], "azimuthreconstructionwithkappa": [69, 71], "azimuthreconstruct": [69, 71], "directionreconstructionwithkappa": [69, 71], "zenithreconstruct": [69, 71], "zenithreconstructionwithkappa": [69, 71], "energyreconstruct": [69, 71], "energyreconstructionwithpow": [69, 71], "energyreconstructionwithuncertainti": [69, 71], "vertexreconstruct": [69, 71], "positionreconstruct": [69, 71], "timereconstruct": [69, 71], "inelasticityreconstruct": [69, 71], "identitytask": [69, 70, 72], "arg": [70, 72, 80, 84, 86, 91, 95], "classifi": 70, "untransform": 70, "logit": [70, 80], "affin": [70, 71, 72], "hidden_s": [70, 71, 72], "transform_prediction_and_target": [70, 71, 72], "transform_target": [70, 71, 72], "transform_infer": [70, 71, 72], "transform_support": [70, 71, 72], "binari": [70, 80], "feed": [70, 71, 72], "lossfunct": [70, 71, 72, 77, 80], "auto": [70, 71, 72], "matic": [70, 71, 72], "_pred": [70, 71, 72], "numer": [70, 71, 72], "stabl": [70, 71, 72], "log10": [70, 71, 72, 82], "rather": [70, 71, 72, 95], "conjunct": [70, 71, 72], "invers": [70, 71, 72], "recov": [70, 71, 72], "minimum": [70, 71, 72], "restrict": [70, 71, 72, 80], "invert": [70, 71, 72], "1e6": [70, 71, 72], "default_target_label": [70, 71, 72], "default_prediction_label": [70, 71, 72], "target_pr": 70, "angl": [71, 79], "kappa": [71, 80], "var": 71, "azimuth_pr": 71, "azimuth_kappa": 71, "3d": [71, 80], "vmf": 71, "dir_x_pr": 71, "dir_y_pr": 71, "dir_z_pr": 71, "direction_kappa": 71, "zenith_pr": 71, "zenith_kappa": 71, "energy_pr": 71, "uncertainti": 71, "energy_sigma": 71, "vertex": 71, "position_x_pr": 71, "position_y_pr": 71, "position_z_pr": 71, "interaction_time_pr": 71, "interact": 71, "hadron": 71, "inelasticity_pr": 71, "wrt": 72, "train_ev": 72, "xyzt": 73, "homophili": 73, "notic": [73, 80], "xyz_coord": 73, "pairwis": 73, "nb_dom": 73, "updat": [73, 75, 78], "config_updat": [74, 75], "weightfitt": [74, 75, 77, 82], "contourfitt": [74, 75], "read_entri": [74, 76], "plot_2d_contour": [74, 76], "plot_1d_contour": [74, 76], "contour": [75, 76], "config_path": 75, "new_config_path": 75, "dummy_sect": 75, "temp": 75, "dummi": 75, "section": 75, "header": 75, "configupdat": 75, "programat": 75, "statistical_fit": 75, "fit_weight": [75, 82], "config_outdir": 75, "weight_nam": [75, 82], "pisa_config_dict": 75, "add_to_databas": [75, 82], "flux": 75, "_database_path": 75, "statist": 75, "effect": [75, 78, 98], "account": 75, "systemat": 75, "hypersurfac": 75, "assumpt": 75, "regard": 75, "pipeline_path": 75, "post_fix": 75, "include_retro": 75, "fit_1d_contour": 75, "run_nam": 75, "config_dict": 75, "grid_siz": 75, "theta23_minmax": 75, "36": 75, "54": 75, "dm31_minmax": 75, "1d": [75, 76], "fit_2d_contour": 75, "2d": [75, 76, 80], "content": 76, "contour_data": 76, "xlim": 76, "ylim": 76, "0023799999999999997": 76, "0025499999999999997": 76, "chi2_critical_valu": 76, "width": 76, "height": 76, "path_to_pisa_fit_result": 76, "name_of_my_model_in_fit": 76, "legend": 76, "color": 76, "linestyl": 76, "style": [76, 98], "line": [76, 78, 84], "upper": 76, "axi": 76, "605": 76, "critic": [76, 95], "chi2": 76, "90": 76, "cl": 76, "right": [76, 80], "176": 76, "inch": 76, "388": 76, "706": 76, "abov": [76, 80, 82, 100], "352": 76, "piecewiselinearlr": [77, 78], "progressbar": [77, 78], "mseloss": [77, 80], "rmseloss": [77, 80], "logcoshloss": [77, 80], "crossentropyloss": [77, 80], "binarycrossentropyloss": [77, 80], "logcmk": [77, 80], "vonmisesfisherloss": [77, 80], "vonmisesfisher2dloss": [77, 80], "euclideandistanceloss": [77, 80], "vonmisesfisher3dloss": [77, 80], "make_dataload": [77, 81], "make_train_validation_dataload": [77, 81], "get_predict": [77, 81], "save_result": [77, 81], "uniform": [77, 82], "bjoernlow": [77, 82], "mileston": 78, "factor": 78, "last_epoch": 78, "_lrschedul": 78, "interpol": 78, "linearli": 78, "denot": 78, "multipli": 78, "closest": 78, "vice": 78, "versa": 78, "wrap": [78, 88, 89], "epoch": [78, 84], "print": [78, 95], "stdout": 78, "get_lr": 78, "refresh_r": 78, "process_posit": 78, "tqdmprogressbar": 78, "progress": 78, "bar": 78, "customis": 78, "lightn": 78, "init_validation_tqdm": 78, "overrid": 78, "init_predict_tqdm": 78, "init_test_tqdm": 78, "init_train_tqdm": 78, "get_metr": 78, "on_train_epoch_start": 78, "previou": 78, "behaviour": 78, "on_train_epoch_end": 78, "don": [78, 100], "duplciat": 78, "runtim": [79, 100], "azimuth_kei": 79, "zenith_kei": 79, "access": [79, 100], "azimiuth": 79, "return_el": 80, "elementwis": 80, "term": 80, "squar": 80, "root": [80, 100], "cosh": 80, "act": 80, "cross": 80, "entropi": 80, "num_class": 80, "softmax": 80, "ed": 80, "probabl": 80, "mit": 80, "licens": 80, "copyright": 80, "2019": 80, "ryabinin": 80, "permiss": 80, "herebi": 80, "person": 80, "copi": 80, "document": 80, "deal": 80, "modifi": 80, "publish": 80, "sublicens": 80, "sell": 80, "permit": 80, "whom": 80, "furnish": 80, "so": [80, 100], "subject": 80, "condit": 80, "shall": 80, "substanti": 80, "portion": 80, "THE": 80, "AS": 80, "warranti": 80, "OF": 80, "kind": 80, "OR": 80, "impli": 80, "BUT": 80, "TO": 80, "merchant": 80, "FOR": 80, "particular": [80, 98], "AND": 80, "noninfring": 80, "IN": 80, "NO": 80, "holder": 80, "BE": 80, "liabl": 80, "claim": 80, "damag": 80, "liabil": 80, "action": 80, "contract": 80, "tort": 80, "aris": 80, "WITH": 80, "_____________________": 80, "mryab": 80, "vmf_loss": 80, "master": 80, "py": [80, 100], "bessel": 80, "exponenti": 80, "ditto": 80, "iv": 80, "1812": 80, "04616": 80, "spite": 80, "suggest": 80, "sec": 80, "paper": 80, "m": 80, "correct": 80, "static": [80, 98], "ctx": 80, "backward": 80, "grad_output": 80, "von": 80, "mise": 80, "fisher": 80, "log_cmk_exact": 80, "c_": 80, "exactli": [80, 95], "log_cmk_approx": 80, "approx": 80, "minu": 80, "sign": 80, "log_cmk": 80, "kappa_switch": 80, "diverg": 80, "700": 80, "float64": 80, "precis": 80, "unaccur": 80, "switch": 80, "three": 80, "database_indic": 81, "test_siz": 81, "node_level": 81, "tag": [81, 98, 100], "archiv": 81, "public": 82, "uniformweightfitt": 82, "bin": 82, "privat": 82, "_fit_weight": 82, "sql": 82, "desir": [82, 93], "np": 82, "happen": 82, "x_low": 82, "wherea": 82, "curv": 82, "base_config": [83, 85], "dataset_config": [83, 85], "training_config": [83, 85], "argumentpars": [83, 84], "is_gcd_fil": [83, 93], "is_i3_fil": [83, 93], "has_extens": [83, 93], "find_i3_fil": [83, 93], "has_icecube_packag": [83, 94], "has_torch_packag": [83, 94], "has_pisa_packag": [83, 94], "requires_icecub": [83, 94], "repeatfilt": [83, 95], "eps_lik": [83, 96], "consist": [84, 95, 98], "cli": 84, "pop_default": 84, "usag": 84, "descript": 84, "command": [84, 100], "standard_argu": 84, "home": [84, 100], "runner": 84, "lib": [84, 100], "python3": 84, "training_example_data_sqlit": 84, "earli": 84, "patienc": 84, "narg": 84, "50": 84, "example_energy_reconstruction_model": 84, "num": 84, "fetch": 84, "with_standard_argu": 84, "overwritten": [84, 86], "baseconfig": [85, 86, 87, 88, 89, 91], "get_all_argument_valu": [85, 86], "save_dataset_config": [85, 88], "save_model_config": [85, 89], "traverse_and_appli": [85, 90], "list_all_submodul": [85, 90], "get_all_grapnet_class": [85, 90], "is_graphnet_modul": [85, 90], "is_graphnet_class": [85, 90], "get_graphnet_class": [85, 90], "trainingconfig": [85, 91], "basemodel": [86, 88, 89], "keyword": [86, 91], "validationerror": [86, 91], "pydantic_cor": [86, 91], "__init__": [86, 88, 89, 91, 100], "__pydantic_self__": [86, 91], "dump": [86, 88, 89], "yaml": [86, 87], "as_dict": [86, 88, 89], "classvar": [86, 88, 89, 91], "configdict": [86, 88, 89, 91], "conform": [86, 88, 89, 91], "pydant": [86, 88, 89, 91], "model_field": [86, 88, 89, 91], "fieldinfo": [86, 88, 89, 91], "metadata": [86, 88, 89, 91], "about": [86, 88, 89, 91], "__fields__": [86, 88, 89, 91], "v1": [86, 88, 89, 91, 100], "re": [87, 100], "save_config": 87, "dataconfig": 88, "transpar": [88, 89, 98], "reproduc": [88, 89], "In": [88, 89, 100], "session": [88, 89], "anoth": [88, 89], "you": [88, 89, 98, 100], "still": 88, "csv": 88, "train_select": 88, "test_select": 88, "unambigu": [88, 89], "annot": [88, 89, 91], "nonetyp": 88, "init_fn": [88, 89], "trainabl": 89, "hyperparamet": 89, "instanti": 89, "thu": 89, "fn_kwarg": 90, "structur": 90, "moduletyp": 90, "grapnet": 90, "lookup": 90, "early_stopping_pati": 91, "system": [93, 100], "filenam": 93, "dir": 93, "search": 93, "test_funct": 94, "filter": 95, "repeat": 95, "nb_repeats_allow": 95, "record": 95, "logrecord": 95, "clear": 95, "intuit": 95, "composit": 95, "loggeradapt": 95, "clash": 95, "setlevel": 95, "deleg": 95, "msg": 95, "warn": 95, "info": [95, 100], "debug": 95, "warning_onc": 95, "onc": 95, "handler": 95, "file_handl": 95, "filehandl": 95, "stream_handl": 95, "streamhandl": 95, "assort": 96, "ep": 96, "api": 97, "To": [98, 100], "sure": [98, 100], "smooth": 98, "guidelin": 98, "guid": 98, "encourag": 98, "contributor": 98, "discuss": 98, "bug": 98, "anyth": 98, "place": 98, "describ": 98, "yourself": 98, "ownership": 98, "prioriti": 98, "situat": 98, "lot": 98, "effort": 98, "go": 98, "turn": 98, "outsid": 98, "scope": 98, "better": 98, "fork": 98, "repo": 98, "dedic": 98, "branch": [98, 100], "repositori": 98, "own": [98, 100], "accept": 98, "autom": 98, "review": 98, "pep8": 98, "docstr": 98, "googl": 98, "hint": 98, "adher": 98, "pep": 98, "pylint": 98, "flake8": 98, "black": 98, "well": 98, "recommend": [98, 100], "mypi": 98, "pydocstyl": 98, "docformatt": 98, "commit": 98, "hook": 98, "instal": 98, "come": 98, "pip": [98, 100], "Then": 98, "everytim": 98, "pep257": 98, "concept": 98, "ljvmiranda921": 98, "io": 98, "notebook": 98, "2018": 98, "06": 98, "21": 98, "precommit": 98, "environ": 100, "virtual": 100, "anaconda": 100, "prove": 100, "instruct": 100, "setup": 100, "want": 100, "part": 100, "achiev": 100, "bash": 100, "shell": 100, "cvmf": 100, "opensciencegrid": 100, "py3": 100, "v4": 100, "sh": 100, "rhel_7_x86_64": 100, "metaproject": 100, "env": 100, "alia": 100, "script": 100, "With": 100, "now": 100, "light": 100, "extra": 100, "geometr": 100, "won": 100, "later": 100, "torch_cpu": 100, "txt": 100, "cpu": 100, "torch_gpu": 100, "prefer": 100, "unix": 100, "git": 100, "clone": 100, "usernam": 100, "cd": 100, "conda": 100, "gcc_linux": 100, "64": 100, "gxx_linux": 100, "libgcc": 100, "cudatoolkit": 100, "11": 100, "forg": 100, "torch_maco": 100, "On": 100, "maco": 100, "box": 100, "compil": 100, "gcc": 100, "date": 100, "possibli": 100, "cuda": 100, "toolkit": 100, "recent": 100, "omit": 100, "newer": 100, "export": 100, "ld_library_path": 100, "anaconda3": 100, "miniconda3": 100, "bashrc": 100, "librari": 100, "rm": 100, "asogaard": 100, "latest": 100, "dc423315742c": 100, "01_icetrai": 100, "01_convert_i3_fil": 100, "2023": 100, "01": 100, "24": 100, "41": 100, "27": 100, "graphnet_20230124": 100, "134127": 100, "46": 100, "convert_i3_fil": 100, "ic86": 100, "thread": 100, "00": 100, "79": 100, "42": 100, "26": 100, "413": 100, "88it": 100, "specialis": 100, "ones": 100, "push": 100, "vx": 100}, "objects": {"": [[1, 0, 0, "-", "graphnet"]], "graphnet": [[2, 0, 0, "-", "constants"], [3, 0, 0, "-", "data"], [41, 0, 0, "-", "deployment"], [45, 0, 0, "-", "models"], [74, 0, 0, "-", "pisa"], [77, 0, 0, "-", "training"], [83, 0, 0, "-", "utilities"]], "graphnet.data": [[4, 0, 0, "-", "constants"], [5, 0, 0, "-", "dataconverter"], [6, 0, 0, "-", "dataloader"], [7, 0, 0, "-", "dataset"], [14, 0, 0, "-", "extractors"], [31, 0, 0, "-", "parquet"], [33, 0, 0, "-", "pipeline"], [34, 0, 0, "-", "sqlite"], [37, 0, 0, "-", "utilities"]], "graphnet.data.constants": [[4, 1, 1, "", "FEATURES"], [4, 1, 1, "", "TRUTH"]], "graphnet.data.constants.FEATURES": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.constants.TRUTH": [[4, 2, 1, "", "DEEPCORE"], [4, 2, 1, "", "ICECUBE86"], [4, 2, 1, "", "KAGGLE"], [4, 2, 1, "", "PROMETHEUS"], [4, 2, 1, "", "UPGRADE"]], "graphnet.data.dataconverter": [[5, 1, 1, "", "DataConverter"], [5, 1, 1, "", "FileSet"], [5, 5, 1, "", "cache_output_files"], [5, 5, 1, "", "init_global_index"]], "graphnet.data.dataconverter.DataConverter": [[5, 3, 1, "", "execute"], [5, 4, 1, "", "file_suffix"], [5, 3, 1, "", "get_map_function"], [5, 3, 1, "", "merge_files"], [5, 3, 1, "", "save_data"]], "graphnet.data.dataconverter.FileSet": [[5, 2, 1, "", "gcd_file"], [5, 2, 1, "", "i3_file"]], "graphnet.data.dataloader": [[6, 1, 1, "", "DataLoader"], [6, 5, 1, "", "collate_fn"], [6, 5, 1, "", "do_shuffle"]], "graphnet.data.dataloader.DataLoader": [[6, 3, 1, "", "from_dataset_config"]], "graphnet.data.dataset": [[8, 0, 0, "-", "dataset"], [9, 0, 0, "-", "parquet"], [11, 0, 0, "-", "sqlite"]], "graphnet.data.dataset.dataset": [[8, 6, 1, "", "ColumnMissingException"], [8, 1, 1, "", "Dataset"], [8, 1, 1, "", "EnsembleDataset"], [8, 5, 1, "", "load_module"], [8, 5, 1, "", "parse_graph_definition"]], "graphnet.data.dataset.dataset.Dataset": [[8, 3, 1, "", "add_label"], [8, 3, 1, "", "concatenate"], [8, 3, 1, "", "from_config"], [8, 4, 1, "", "path"], [8, 3, 1, "", "query_table"], [8, 4, 1, "", "truth_table"]], "graphnet.data.dataset.parquet": [[10, 0, 0, "-", "parquet_dataset"]], "graphnet.data.dataset.parquet.parquet_dataset": [[10, 1, 1, "", "ParquetDataset"]], "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset": [[10, 3, 1, "", "query_table"]], "graphnet.data.dataset.sqlite": [[12, 0, 0, "-", "sqlite_dataset"], [13, 0, 0, "-", "sqlite_dataset_perturbed"]], "graphnet.data.dataset.sqlite.sqlite_dataset": [[12, 1, 1, "", "SQLiteDataset"]], "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset": [[12, 3, 1, "", "query_table"]], "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed": [[13, 1, 1, "", "SQLiteDatasetPerturbed"]], "graphnet.data.extractors": [[15, 0, 0, "-", "i3extractor"], [16, 0, 0, "-", "i3featureextractor"], [17, 0, 0, "-", "i3genericextractor"], [18, 0, 0, "-", "i3hybridrecoextractor"], [19, 0, 0, "-", "i3ntmuonlabelsextractor"], [20, 0, 0, "-", "i3particleextractor"], [21, 0, 0, "-", "i3pisaextractor"], [22, 0, 0, "-", "i3quesoextractor"], [23, 0, 0, "-", "i3retroextractor"], [24, 0, 0, "-", "i3splinempeextractor"], [25, 0, 0, "-", "i3truthextractor"], [26, 0, 0, "-", "i3tumextractor"], [27, 0, 0, "-", "utilities"]], "graphnet.data.extractors.i3extractor": [[15, 1, 1, "", "I3Extractor"], [15, 1, 1, "", "I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor.I3Extractor": [[15, 4, 1, "", "name"], [15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3extractor.I3ExtractorCollection": [[15, 3, 1, "", "set_files"]], "graphnet.data.extractors.i3featureextractor": [[16, 1, 1, "", "I3FeatureExtractor"], [16, 1, 1, "", "I3FeatureExtractorIceCube86"], [16, 1, 1, "", "I3FeatureExtractorIceCubeDeepCore"], [16, 1, 1, "", "I3FeatureExtractorIceCubeUpgrade"], [16, 1, 1, "", "I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3genericextractor": [[17, 1, 1, "", "I3GenericExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, 1, 1, "", "I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, 1, 1, "", "I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, 1, 1, "", "I3ParticleExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, 1, 1, "", "I3PISAExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, 1, 1, "", "I3QUESOExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, 1, 1, "", "I3RetroExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, 1, 1, "", "I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, 1, 1, "", "I3TruthExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, 1, 1, "", "I3TUMExtractor"]], "graphnet.data.extractors.utilities": [[28, 0, 0, "-", "collections"], [29, 0, 0, "-", "frames"], [30, 0, 0, "-", "types"]], "graphnet.data.extractors.utilities.collections": [[28, 5, 1, "", "flatten_nested_dictionary"], [28, 5, 1, "", "serialise"], [28, 5, 1, "", "transpose_list_of_dicts"]], "graphnet.data.extractors.utilities.frames": [[29, 5, 1, "", "frame_is_montecarlo"], [29, 5, 1, "", "frame_is_noise"], [29, 5, 1, "", "get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.types": [[30, 5, 1, "", "break_cyclic_recursion"], [30, 5, 1, "", "cast_object_to_pure_python"], [30, 5, 1, "", "cast_pulse_series_to_pure_python"], [30, 5, 1, "", "get_member_variables"], [30, 5, 1, "", "is_boost_class"], [30, 5, 1, "", "is_boost_enum"], [30, 5, 1, "", "is_icecube_class"], [30, 5, 1, "", "is_method"], [30, 5, 1, "", "is_type"]], "graphnet.data.parquet": [[32, 0, 0, "-", "parquet_dataconverter"]], "graphnet.data.parquet.parquet_dataconverter": [[32, 1, 1, "", "ParquetDataConverter"]], "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter": [[32, 2, 1, "", "file_suffix"], [32, 3, 1, "", "merge_files"], [32, 3, 1, "", "save_data"]], "graphnet.data.pipeline": [[33, 1, 1, "", "InSQLitePipeline"]], "graphnet.data.sqlite": [[35, 0, 0, "-", "sqlite_dataconverter"], [36, 0, 0, "-", "sqlite_utilities"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, 1, 1, "", "SQLiteDataConverter"], [35, 5, 1, "", "construct_dataframe"], [35, 5, 1, "", "is_mc_tree"], [35, 5, 1, "", "is_pulse_map"]], "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter": [[35, 3, 1, "", "any_pulsemap_is_non_empty"], [35, 2, 1, "", "file_suffix"], [35, 3, 1, "", "merge_files"], [35, 3, 1, "", "save_data"]], "graphnet.data.sqlite.sqlite_utilities": [[36, 5, 1, "", "attach_index"], [36, 5, 1, "", "create_table"], [36, 5, 1, "", "create_table_and_save_to_sql"], [36, 5, 1, "", "database_exists"], [36, 5, 1, "", "database_table_exists"], [36, 5, 1, "", "run_sql_code"], [36, 5, 1, "", "save_to_sql"]], "graphnet.data.utilities": [[38, 0, 0, "-", "parquet_to_sqlite"], [39, 0, 0, "-", "random"], [40, 0, 0, "-", "string_selection_resolver"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, 1, 1, "", "ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter": [[38, 3, 1, "", "run"]], "graphnet.data.utilities.random": [[39, 5, 1, "", "pairwise_shuffle"]], "graphnet.data.utilities.string_selection_resolver": [[40, 1, 1, "", "StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver": [[40, 3, 1, "", "resolve"]], "graphnet.deployment.i3modules": [[44, 0, 0, "-", "graphnet_module"]], "graphnet.deployment.i3modules.graphnet_module": [[44, 1, 1, "", "GraphNeTI3Module"], [44, 1, 1, "", "I3InferenceModule"], [44, 1, 1, "", "I3PulseCleanerModule"]], "graphnet.models": [[46, 0, 0, "-", "coarsening"], [47, 0, 0, "-", "components"], [50, 0, 0, "-", "detector"], [54, 0, 0, "-", "gnn"], [60, 0, 0, "-", "graphs"], [67, 0, 0, "-", "model"], [68, 0, 0, "-", "standard_model"], [69, 0, 0, "-", "task"], [73, 0, 0, "-", "utils"]], "graphnet.models.coarsening": [[46, 1, 1, "", "AttributeCoarsening"], [46, 1, 1, "", "Coarsening"], [46, 1, 1, "", "CustomDOMCoarsening"], [46, 1, 1, "", "DOMAndTimeWindowCoarsening"], [46, 1, 1, "", "DOMCoarsening"], [46, 5, 1, "", "unbatch_edge_index"]], "graphnet.models.coarsening.Coarsening": [[46, 3, 1, "", "forward"], [46, 2, 1, "", "reduce_options"]], "graphnet.models.components": [[48, 0, 0, "-", "layers"], [49, 0, 0, "-", "pool"]], "graphnet.models.components.layers": [[48, 1, 1, "", "DynEdgeConv"], [48, 1, 1, "", "DynTrans"], [48, 1, 1, "", "EdgeConvTito"]], "graphnet.models.components.layers.DynEdgeConv": [[48, 3, 1, "", "forward"]], "graphnet.models.components.layers.DynTrans": [[48, 3, 1, "", "forward"]], "graphnet.models.components.layers.EdgeConvTito": [[48, 3, 1, "", "forward"], [48, 3, 1, "", "message"], [48, 3, 1, "", "reset_parameters"]], "graphnet.models.components.pool": [[49, 5, 1, "", "group_by"], [49, 5, 1, "", "group_pulses_to_dom"], [49, 5, 1, "", "group_pulses_to_pmt"], [49, 5, 1, "", "min_pool"], [49, 5, 1, "", "min_pool_x"], [49, 5, 1, "", "std_pool"], [49, 5, 1, "", "std_pool_x"], [49, 5, 1, "", "sum_pool"], [49, 5, 1, "", "sum_pool_and_distribute"], [49, 5, 1, "", "sum_pool_x"]], "graphnet.models.detector": [[51, 0, 0, "-", "detector"], [52, 0, 0, "-", "icecube"], [53, 0, 0, "-", "prometheus"]], "graphnet.models.detector.detector": [[51, 1, 1, "", "Detector"]], "graphnet.models.detector.detector.Detector": [[51, 3, 1, "", "feature_map"], [51, 3, 1, "", "forward"]], "graphnet.models.detector.icecube": [[52, 1, 1, "", "IceCube86"], [52, 1, 1, "", "IceCubeDeepCore"], [52, 1, 1, "", "IceCubeKaggle"], [52, 1, 1, "", "IceCubeUpgrade"]], "graphnet.models.detector.icecube.IceCube86": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeDeepCore": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeKaggle": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.icecube.IceCubeUpgrade": [[52, 3, 1, "", "feature_map"]], "graphnet.models.detector.prometheus": [[53, 1, 1, "", "Prometheus"]], "graphnet.models.detector.prometheus.Prometheus": [[53, 3, 1, "", "feature_map"]], "graphnet.models.gnn": [[55, 0, 0, "-", "convnet"], [56, 0, 0, "-", "dynedge"], [57, 0, 0, "-", "dynedge_jinst"], [58, 0, 0, "-", "dynedge_kaggle_tito"], [59, 0, 0, "-", "gnn"]], "graphnet.models.gnn.convnet": [[55, 1, 1, "", "ConvNet"]], "graphnet.models.gnn.convnet.ConvNet": [[55, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge": [[56, 1, 1, "", "DynEdge"]], "graphnet.models.gnn.dynedge.DynEdge": [[56, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge_jinst": [[57, 1, 1, "", "DynEdgeJINST"]], "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST": [[57, 3, 1, "", "forward"]], "graphnet.models.gnn.dynedge_kaggle_tito": [[58, 1, 1, "", "DynEdgeTITO"]], "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO": [[58, 3, 1, "", "forward"]], "graphnet.models.gnn.gnn": [[59, 1, 1, "", "GNN"]], "graphnet.models.gnn.gnn.GNN": [[59, 3, 1, "", "forward"], [59, 4, 1, "", "nb_inputs"], [59, 4, 1, "", "nb_outputs"]], "graphnet.models.graphs": [[61, 0, 0, "-", "edges"], [63, 0, 0, "-", "graph_definition"], [64, 0, 0, "-", "graphs"], [65, 0, 0, "-", "nodes"]], "graphnet.models.graphs.edges": [[62, 0, 0, "-", "edges"]], "graphnet.models.graphs.edges.edges": [[62, 1, 1, "", "EdgeDefinition"], [62, 1, 1, "", "EuclideanEdges"], [62, 1, 1, "", "KNNEdges"], [62, 1, 1, "", "RadialEdges"]], "graphnet.models.graphs.edges.edges.EdgeDefinition": [[62, 3, 1, "", "forward"]], "graphnet.models.graphs.graph_definition": [[63, 1, 1, "", "GraphDefinition"]], "graphnet.models.graphs.graph_definition.GraphDefinition": [[63, 3, 1, "", "forward"]], "graphnet.models.graphs.graphs": [[64, 1, 1, "", "KNNGraph"]], "graphnet.models.graphs.nodes": [[66, 0, 0, "-", "nodes"]], "graphnet.models.graphs.nodes.nodes": [[66, 1, 1, "", "NodeDefinition"], [66, 1, 1, "", "NodesAsPulses"]], "graphnet.models.graphs.nodes.nodes.NodeDefinition": [[66, 3, 1, "", "forward"], [66, 4, 1, "", "nb_outputs"], [66, 3, 1, "", "set_number_of_inputs"]], "graphnet.models.model": [[67, 1, 1, "", "Model"]], "graphnet.models.model.Model": [[67, 3, 1, "", "fit"], [67, 3, 1, "", "forward"], [67, 3, 1, "", "from_config"], [67, 3, 1, "", "load"], [67, 3, 1, "", "load_state_dict"], [67, 3, 1, "", "predict"], [67, 3, 1, "", "predict_as_dataframe"], [67, 3, 1, "", "save"], [67, 3, 1, "", "save_state_dict"]], "graphnet.models.standard_model": [[68, 1, 1, "", "StandardModel"]], "graphnet.models.standard_model.StandardModel": [[68, 3, 1, "", "compute_loss"], [68, 3, 1, "", "configure_optimizers"], [68, 3, 1, "", "forward"], [68, 3, 1, "", "inference"], [68, 3, 1, "", "predict"], [68, 3, 1, "", "predict_as_dataframe"], [68, 4, 1, "", "prediction_labels"], [68, 3, 1, "", "shared_step"], [68, 4, 1, "", "target_labels"], [68, 3, 1, "", "train"], [68, 3, 1, "", "training_step"], [68, 3, 1, "", "validation_step"]], "graphnet.models.task": [[70, 0, 0, "-", "classification"], [71, 0, 0, "-", "reconstruction"], [72, 0, 0, "-", "task"]], "graphnet.models.task.classification": [[70, 1, 1, "", "BinaryClassificationTask"], [70, 1, 1, "", "BinaryClassificationTaskLogits"], [70, 1, 1, "", "MulticlassClassificationTask"]], "graphnet.models.task.classification.BinaryClassificationTask": [[70, 2, 1, "", "default_prediction_labels"], [70, 2, 1, "", "default_target_labels"], [70, 2, 1, "", "nb_inputs"]], "graphnet.models.task.classification.BinaryClassificationTaskLogits": [[70, 2, 1, "", "default_prediction_labels"], [70, 2, 1, "", "default_target_labels"], [70, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction": [[71, 1, 1, "", "AzimuthReconstruction"], [71, 1, 1, "", "AzimuthReconstructionWithKappa"], [71, 1, 1, "", "DirectionReconstructionWithKappa"], [71, 1, 1, "", "EnergyReconstruction"], [71, 1, 1, "", "EnergyReconstructionWithPower"], [71, 1, 1, "", "EnergyReconstructionWithUncertainty"], [71, 1, 1, "", "InelasticityReconstruction"], [71, 1, 1, "", "PositionReconstruction"], [71, 1, 1, "", "TimeReconstruction"], [71, 1, 1, "", "VertexReconstruction"], [71, 1, 1, "", "ZenithReconstruction"], [71, 1, 1, "", "ZenithReconstructionWithKappa"]], "graphnet.models.task.reconstruction.AzimuthReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstructionWithPower": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.InelasticityReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.PositionReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.TimeReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.VertexReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.ZenithReconstruction": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa": [[71, 2, 1, "", "default_prediction_labels"], [71, 2, 1, "", "default_target_labels"], [71, 2, 1, "", "nb_inputs"]], "graphnet.models.task.task": [[72, 1, 1, "", "IdentityTask"], [72, 1, 1, "", "Task"]], "graphnet.models.task.task.IdentityTask": [[72, 4, 1, "", "default_prediction_labels"], [72, 4, 1, "", "default_target_labels"], [72, 4, 1, "", "nb_inputs"]], "graphnet.models.task.task.Task": [[72, 3, 1, "", "compute_loss"], [72, 4, 1, "", "default_prediction_labels"], [72, 4, 1, "", "default_target_labels"], [72, 3, 1, "", "forward"], [72, 3, 1, "", "inference"], [72, 4, 1, "", "nb_inputs"], [72, 3, 1, "", "train_eval"]], "graphnet.models.utils": [[73, 5, 1, "", "calculate_distance_matrix"], [73, 5, 1, "", "calculate_xyzt_homophily"], [73, 5, 1, "", "knn_graph_batch"]], "graphnet.pisa": [[75, 0, 0, "-", "fitting"], [76, 0, 0, "-", "plotting"]], "graphnet.pisa.fitting": [[75, 1, 1, "", "ContourFitter"], [75, 1, 1, "", "WeightFitter"], [75, 5, 1, "", "config_updater"]], "graphnet.pisa.fitting.ContourFitter": [[75, 3, 1, "", "fit_1d_contour"], [75, 3, 1, "", "fit_2d_contour"]], "graphnet.pisa.fitting.WeightFitter": [[75, 3, 1, "", "fit_weights"]], "graphnet.pisa.plotting": [[76, 5, 1, "", "plot_1D_contour"], [76, 5, 1, "", "plot_2D_contour"], [76, 5, 1, "", "read_entry"]], "graphnet.training": [[78, 0, 0, "-", "callbacks"], [79, 0, 0, "-", "labels"], [80, 0, 0, "-", "loss_functions"], [81, 0, 0, "-", "utils"], [82, 0, 0, "-", "weight_fitting"]], "graphnet.training.callbacks": [[78, 1, 1, "", "PiecewiseLinearLR"], [78, 1, 1, "", "ProgressBar"]], "graphnet.training.callbacks.PiecewiseLinearLR": [[78, 3, 1, "", "get_lr"]], "graphnet.training.callbacks.ProgressBar": [[78, 3, 1, "", "get_metrics"], [78, 3, 1, "", "init_predict_tqdm"], [78, 3, 1, "", "init_test_tqdm"], [78, 3, 1, "", "init_train_tqdm"], [78, 3, 1, "", "init_validation_tqdm"], [78, 3, 1, "", "on_train_epoch_end"], [78, 3, 1, "", "on_train_epoch_start"]], "graphnet.training.labels": [[79, 1, 1, "", "Direction"], [79, 1, 1, "", "Label"]], "graphnet.training.labels.Label": [[79, 4, 1, "", "key"]], "graphnet.training.loss_functions": [[80, 1, 1, "", "BinaryCrossEntropyLoss"], [80, 1, 1, "", "CrossEntropyLoss"], [80, 1, 1, "", "EuclideanDistanceLoss"], [80, 1, 1, "", "LogCMK"], [80, 1, 1, "", "LogCoshLoss"], [80, 1, 1, "", "LossFunction"], [80, 1, 1, "", "MSELoss"], [80, 1, 1, "", "RMSELoss"], [80, 1, 1, "", "VonMisesFisher2DLoss"], [80, 1, 1, "", "VonMisesFisher3DLoss"], [80, 1, 1, "", "VonMisesFisherLoss"]], "graphnet.training.loss_functions.LogCMK": [[80, 3, 1, "", "backward"], [80, 3, 1, "", "forward"]], "graphnet.training.loss_functions.LossFunction": [[80, 3, 1, "", "forward"]], "graphnet.training.loss_functions.VonMisesFisherLoss": [[80, 3, 1, "", "log_cmk"], [80, 3, 1, "", "log_cmk_approx"], [80, 3, 1, "", "log_cmk_exact"]], "graphnet.training.utils": [[81, 5, 1, "", "collate_fn"], [81, 5, 1, "", "get_predictions"], [81, 5, 1, "", "make_dataloader"], [81, 5, 1, "", "make_train_validation_dataloader"], [81, 5, 1, "", "save_results"]], "graphnet.training.weight_fitting": [[82, 1, 1, "", "BjoernLow"], [82, 1, 1, "", "Uniform"], [82, 1, 1, "", "WeightFitter"]], "graphnet.training.weight_fitting.WeightFitter": [[82, 3, 1, "", "fit"]], "graphnet.utilities": [[84, 0, 0, "-", "argparse"], [85, 0, 0, "-", "config"], [92, 0, 0, "-", "decorators"], [93, 0, 0, "-", "filesys"], [94, 0, 0, "-", "imports"], [95, 0, 0, "-", "logging"], [96, 0, 0, "-", "maths"]], "graphnet.utilities.argparse": [[84, 1, 1, "", "ArgumentParser"], [84, 1, 1, "", "Options"]], "graphnet.utilities.argparse.ArgumentParser": [[84, 2, 1, "", "standard_arguments"], [84, 3, 1, "", "with_standard_arguments"]], "graphnet.utilities.argparse.Options": [[84, 3, 1, "", "contains"], [84, 3, 1, "", "pop_default"]], "graphnet.utilities.config": [[86, 0, 0, "-", "base_config"], [87, 0, 0, "-", "configurable"], [88, 0, 0, "-", "dataset_config"], [89, 0, 0, "-", "model_config"], [90, 0, 0, "-", "parsing"], [91, 0, 0, "-", "training_config"]], "graphnet.utilities.config.base_config": [[86, 1, 1, "", "BaseConfig"], [86, 5, 1, "", "get_all_argument_values"]], "graphnet.utilities.config.base_config.BaseConfig": [[86, 3, 1, "", "as_dict"], [86, 3, 1, "", "dump"], [86, 3, 1, "", "load"], [86, 2, 1, "", "model_config"], [86, 2, 1, "", "model_fields"]], "graphnet.utilities.config.configurable": [[87, 1, 1, "", "Configurable"]], "graphnet.utilities.config.configurable.Configurable": [[87, 4, 1, "", "config"], [87, 3, 1, "", "from_config"], [87, 3, 1, "", "save_config"]], "graphnet.utilities.config.dataset_config": [[88, 1, 1, "", "DatasetConfig"], [88, 5, 1, "", "save_dataset_config"]], "graphnet.utilities.config.dataset_config.DatasetConfig": [[88, 3, 1, "", "as_dict"], [88, 2, 1, "", "features"], [88, 2, 1, "", "graph_definition"], [88, 2, 1, "", "index_column"], [88, 2, 1, "", "loss_weight_column"], [88, 2, 1, "", "loss_weight_default_value"], [88, 2, 1, "", "loss_weight_table"], [88, 2, 1, "", "model_config"], [88, 2, 1, "", "model_fields"], [88, 2, 1, "", "node_truth"], [88, 2, 1, "", "node_truth_table"], [88, 2, 1, "", "path"], [88, 2, 1, "", "pulsemaps"], [88, 2, 1, "", "seed"], [88, 2, 1, "", "selection"], [88, 2, 1, "", "string_selection"], [88, 2, 1, "", "truth"], [88, 2, 1, "", "truth_table"]], "graphnet.utilities.config.model_config": [[89, 1, 1, "", "ModelConfig"], [89, 5, 1, "", "save_model_config"]], "graphnet.utilities.config.model_config.ModelConfig": [[89, 2, 1, "", "arguments"], [89, 3, 1, "", "as_dict"], [89, 2, 1, "", "class_name"], [89, 2, 1, "", "model_config"], [89, 2, 1, "", "model_fields"]], "graphnet.utilities.config.parsing": [[90, 5, 1, "", "get_all_grapnet_classes"], [90, 5, 1, "", "get_graphnet_classes"], [90, 5, 1, "", "is_graphnet_class"], [90, 5, 1, "", "is_graphnet_module"], [90, 5, 1, "", "list_all_submodules"], [90, 5, 1, "", "traverse_and_apply"]], "graphnet.utilities.config.training_config": [[91, 1, 1, "", "TrainingConfig"]], "graphnet.utilities.config.training_config.TrainingConfig": [[91, 2, 1, "", "dataloader"], [91, 2, 1, "", "early_stopping_patience"], [91, 2, 1, "", "fit"], [91, 2, 1, "", "model_config"], [91, 2, 1, "", "model_fields"], [91, 2, 1, "", "target"]], "graphnet.utilities.filesys": [[93, 5, 1, "", "find_i3_files"], [93, 5, 1, "", "has_extension"], [93, 5, 1, "", "is_gcd_file"], [93, 5, 1, "", "is_i3_file"]], "graphnet.utilities.imports": [[94, 5, 1, "", "has_icecube_package"], [94, 5, 1, "", "has_pisa_package"], [94, 5, 1, "", "has_torch_package"], [94, 5, 1, "", "requires_icecube"]], "graphnet.utilities.logging": [[95, 1, 1, "", "Logger"], [95, 1, 1, "", "RepeatFilter"]], "graphnet.utilities.logging.Logger": [[95, 3, 1, "", "critical"], [95, 3, 1, "", "debug"], [95, 3, 1, "", "error"], [95, 4, 1, "", "file_handlers"], [95, 4, 1, "", "handlers"], [95, 3, 1, "", "info"], [95, 3, 1, "", "setLevel"], [95, 4, 1, "", "stream_handlers"], [95, 3, 1, "", "warning"], [95, 3, 1, "", "warning_once"]], "graphnet.utilities.logging.RepeatFilter": [[95, 3, 1, "", "filter"], [95, 2, 1, "", "nb_repeats_allowed"]], "graphnet.utilities.maths": [[96, 5, 1, "", "eps_like"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:attribute", "3": "py:method", "4": "py:property", "5": "py:function", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "attribute", "Python attribute"], "3": ["py", "method", "Python method"], "4": ["py", "property", "Python property"], "5": ["py", "function", "Python function"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"about": [0, 99], "impact": [0, 99], "usag": [0, 99], "acknowledg": [0, 99], "api": 1, "constant": [2, 4], "data": 3, "dataconvert": 5, "dataload": 6, "dataset": [7, 8], "parquet": [9, 31], "parquet_dataset": 10, "sqlite": [11, 34], "sqlite_dataset": 12, "sqlite_dataset_perturb": 13, "extractor": 14, "i3extractor": 15, "i3featureextractor": 16, "i3genericextractor": 17, "i3hybridrecoextractor": 18, "i3ntmuonlabelsextractor": 19, "i3particleextractor": 20, "i3pisaextractor": 21, "i3quesoextractor": 22, "i3retroextractor": 23, "i3splinempeextractor": 24, "i3truthextractor": 25, "i3tumextractor": 26, "util": [27, 37, 73, 81, 83], "collect": 28, "frame": 29, "type": 30, "parquet_dataconvert": 32, "pipelin": 33, "sqlite_dataconvert": 35, "sqlite_util": 36, "parquet_to_sqlit": 38, "random": 39, "string_selection_resolv": 40, "deploy": [41, 43], "i3modul": 42, "graphnet_modul": 44, "model": [45, 67], "coarsen": 46, "compon": 47, "layer": 48, "pool": 49, "detector": [50, 51], "icecub": 52, "prometheu": 53, "gnn": [54, 59], "convnet": 55, "dynedg": 56, "dynedge_jinst": 57, "dynedge_kaggle_tito": 58, "graph": [60, 64], "edg": [61, 62], "graph_definit": 63, "node": [65, 66], "standard_model": 68, "task": [69, 72], "classif": 70, "reconstruct": 71, "pisa": 74, "fit": 75, "plot": 76, "train": 77, "callback": 78, "label": 79, "loss_funct": 80, "weight_fit": 82, "argpars": 84, "config": 85, "base_config": 86, "configur": 87, "dataset_config": 88, "model_config": 89, "pars": 90, "training_config": 91, "decor": 92, "filesi": 93, "import": 94, "log": 95, "math": 96, "src": 97, "contribut": 98, "github": 98, "issu": 98, "pull": 98, "request": 98, "convent": 98, "code": 98, "qualiti": 98, "instal": 100, "icetrai": 100, "stand": 100, "alon": 100, "run": 100, "docker": 100}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 60}, "alltitles": {"About": [[0, "about"], [99, "about"]], "Impact": [[0, "impact"], [99, "impact"]], "Usage": [[0, "usage"], [99, "usage"]], "Acknowledgements": [[0, "acknowledgements"], [99, "acknowledgements"]], "API": [[1, "module-graphnet"]], "constants": [[2, "module-graphnet.constants"], [4, "module-graphnet.data.constants"]], "data": [[3, "module-graphnet.data"]], "dataconverter": [[5, "module-graphnet.data.dataconverter"]], "dataloader": [[6, "module-graphnet.data.dataloader"]], "dataset": [[7, "module-graphnet.data.dataset"], [8, "module-graphnet.data.dataset.dataset"]], "parquet": [[9, "module-graphnet.data.dataset.parquet"], [31, "module-graphnet.data.parquet"]], "parquet_dataset": [[10, "module-graphnet.data.dataset.parquet.parquet_dataset"]], "sqlite": [[11, "module-graphnet.data.dataset.sqlite"], [34, "module-graphnet.data.sqlite"]], "sqlite_dataset": [[12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"]], "sqlite_dataset_perturbed": [[13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"]], "extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "utilities": [[27, "module-graphnet.data.extractors.utilities"], [37, "module-graphnet.data.utilities"], [83, "module-graphnet.utilities"]], "collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "types": [[30, "module-graphnet.data.extractors.utilities.types"]], "parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "pipeline": [[33, "module-graphnet.data.pipeline"]], "sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "random": [[39, "module-graphnet.data.utilities.random"]], "string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "deployment": [[41, "module-graphnet.deployment"]], "i3modules": [[42, "i3modules"]], "deployer": [[43, "deployer"]], "graphnet_module": [[44, "module-graphnet.deployment.i3modules.graphnet_module"]], "models": [[45, "module-graphnet.models"]], "coarsening": [[46, "module-graphnet.models.coarsening"]], "components": [[47, "module-graphnet.models.components"]], "layers": [[48, "module-graphnet.models.components.layers"]], "pool": [[49, "module-graphnet.models.components.pool"]], "detector": [[50, "module-graphnet.models.detector"], [51, "module-graphnet.models.detector.detector"]], "icecube": [[52, "module-graphnet.models.detector.icecube"]], "prometheus": [[53, "module-graphnet.models.detector.prometheus"]], "gnn": [[54, "module-graphnet.models.gnn"], [59, "module-graphnet.models.gnn.gnn"]], "convnet": [[55, "module-graphnet.models.gnn.convnet"]], "dynedge": [[56, "module-graphnet.models.gnn.dynedge"]], "dynedge_jinst": [[57, "module-graphnet.models.gnn.dynedge_jinst"]], "dynedge_kaggle_tito": [[58, "module-graphnet.models.gnn.dynedge_kaggle_tito"]], "graphs": [[60, "module-graphnet.models.graphs"], [64, "module-graphnet.models.graphs.graphs"]], "edges": [[61, "module-graphnet.models.graphs.edges"], [62, "module-graphnet.models.graphs.edges.edges"]], "graph_definition": [[63, "module-graphnet.models.graphs.graph_definition"]], "nodes": [[65, "module-graphnet.models.graphs.nodes"], [66, "module-graphnet.models.graphs.nodes.nodes"]], "model": [[67, "module-graphnet.models.model"]], "standard_model": [[68, "module-graphnet.models.standard_model"]], "task": [[69, "module-graphnet.models.task"], [72, "module-graphnet.models.task.task"]], "classification": [[70, "module-graphnet.models.task.classification"]], "reconstruction": [[71, "module-graphnet.models.task.reconstruction"]], "utils": [[73, "module-graphnet.models.utils"], [81, "module-graphnet.training.utils"]], "pisa": [[74, "module-graphnet.pisa"]], "fitting": [[75, "module-graphnet.pisa.fitting"]], "plotting": [[76, "module-graphnet.pisa.plotting"]], "training": [[77, "module-graphnet.training"]], "callbacks": [[78, "module-graphnet.training.callbacks"]], "labels": [[79, "module-graphnet.training.labels"]], "loss_functions": [[80, "module-graphnet.training.loss_functions"]], "weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "argparse": [[84, "module-graphnet.utilities.argparse"]], "config": [[85, "module-graphnet.utilities.config"]], "base_config": [[86, "module-graphnet.utilities.config.base_config"]], "configurable": [[87, "module-graphnet.utilities.config.configurable"]], "dataset_config": [[88, "module-graphnet.utilities.config.dataset_config"]], "model_config": [[89, "module-graphnet.utilities.config.model_config"]], "parsing": [[90, "module-graphnet.utilities.config.parsing"]], "training_config": [[91, "module-graphnet.utilities.config.training_config"]], "decorators": [[92, "module-graphnet.utilities.decorators"]], "filesys": [[93, "module-graphnet.utilities.filesys"]], "imports": [[94, "module-graphnet.utilities.imports"]], "logging": [[95, "module-graphnet.utilities.logging"]], "maths": [[96, "module-graphnet.utilities.maths"]], "src": [[97, "src"]], "Contribute": [[98, "contribute"]], "GitHub issues": [[98, "github-issues"]], "Pull requests": [[98, "pull-requests"]], "Conventions": [[98, "conventions"]], "Code quality": [[98, "code-quality"]], "Install": [[100, "install"]], "Installing with IceTray": [[100, "installing-with-icetray"]], "Installing stand-alone": [[100, "installing-stand-alone"]], "Running in Docker": [[100, "running-in-docker"]]}, "indexentries": {"graphnet": [[1, "module-graphnet"]], "module": [[1, "module-graphnet"], [2, "module-graphnet.constants"], [3, "module-graphnet.data"], [4, "module-graphnet.data.constants"], [5, "module-graphnet.data.dataconverter"], [6, "module-graphnet.data.dataloader"], [7, "module-graphnet.data.dataset"], [8, "module-graphnet.data.dataset.dataset"], [9, "module-graphnet.data.dataset.parquet"], [10, "module-graphnet.data.dataset.parquet.parquet_dataset"], [11, "module-graphnet.data.dataset.sqlite"], [12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"], [13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"], [14, "module-graphnet.data.extractors"], [15, "module-graphnet.data.extractors.i3extractor"], [16, "module-graphnet.data.extractors.i3featureextractor"], [17, "module-graphnet.data.extractors.i3genericextractor"], [18, "module-graphnet.data.extractors.i3hybridrecoextractor"], [19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"], [20, "module-graphnet.data.extractors.i3particleextractor"], [21, "module-graphnet.data.extractors.i3pisaextractor"], [22, "module-graphnet.data.extractors.i3quesoextractor"], [23, "module-graphnet.data.extractors.i3retroextractor"], [24, "module-graphnet.data.extractors.i3splinempeextractor"], [25, "module-graphnet.data.extractors.i3truthextractor"], [26, "module-graphnet.data.extractors.i3tumextractor"], [27, "module-graphnet.data.extractors.utilities"], [28, "module-graphnet.data.extractors.utilities.collections"], [29, "module-graphnet.data.extractors.utilities.frames"], [30, "module-graphnet.data.extractors.utilities.types"], [31, "module-graphnet.data.parquet"], [32, "module-graphnet.data.parquet.parquet_dataconverter"], [33, "module-graphnet.data.pipeline"], [34, "module-graphnet.data.sqlite"], [35, "module-graphnet.data.sqlite.sqlite_dataconverter"], [36, "module-graphnet.data.sqlite.sqlite_utilities"], [37, "module-graphnet.data.utilities"], [38, "module-graphnet.data.utilities.parquet_to_sqlite"], [39, "module-graphnet.data.utilities.random"], [40, "module-graphnet.data.utilities.string_selection_resolver"], [41, "module-graphnet.deployment"], [44, "module-graphnet.deployment.i3modules.graphnet_module"], [45, "module-graphnet.models"], [46, "module-graphnet.models.coarsening"], [47, "module-graphnet.models.components"], [48, "module-graphnet.models.components.layers"], [49, "module-graphnet.models.components.pool"], [50, "module-graphnet.models.detector"], [51, "module-graphnet.models.detector.detector"], [52, "module-graphnet.models.detector.icecube"], [53, "module-graphnet.models.detector.prometheus"], [54, "module-graphnet.models.gnn"], [55, "module-graphnet.models.gnn.convnet"], [56, "module-graphnet.models.gnn.dynedge"], [57, "module-graphnet.models.gnn.dynedge_jinst"], [58, "module-graphnet.models.gnn.dynedge_kaggle_tito"], [59, "module-graphnet.models.gnn.gnn"], [60, "module-graphnet.models.graphs"], [61, "module-graphnet.models.graphs.edges"], [62, "module-graphnet.models.graphs.edges.edges"], [63, "module-graphnet.models.graphs.graph_definition"], [64, "module-graphnet.models.graphs.graphs"], [65, "module-graphnet.models.graphs.nodes"], [66, "module-graphnet.models.graphs.nodes.nodes"], [67, "module-graphnet.models.model"], [68, "module-graphnet.models.standard_model"], [69, "module-graphnet.models.task"], [70, "module-graphnet.models.task.classification"], [71, "module-graphnet.models.task.reconstruction"], [72, "module-graphnet.models.task.task"], [73, "module-graphnet.models.utils"], [74, "module-graphnet.pisa"], [75, "module-graphnet.pisa.fitting"], [76, "module-graphnet.pisa.plotting"], [77, "module-graphnet.training"], [78, "module-graphnet.training.callbacks"], [79, "module-graphnet.training.labels"], [80, "module-graphnet.training.loss_functions"], [81, "module-graphnet.training.utils"], [82, "module-graphnet.training.weight_fitting"], [83, "module-graphnet.utilities"], [84, "module-graphnet.utilities.argparse"], [85, "module-graphnet.utilities.config"], [86, "module-graphnet.utilities.config.base_config"], [87, "module-graphnet.utilities.config.configurable"], [88, "module-graphnet.utilities.config.dataset_config"], [89, "module-graphnet.utilities.config.model_config"], [90, "module-graphnet.utilities.config.parsing"], [91, "module-graphnet.utilities.config.training_config"], [92, "module-graphnet.utilities.decorators"], [93, "module-graphnet.utilities.filesys"], [94, "module-graphnet.utilities.imports"], [95, "module-graphnet.utilities.logging"], [96, "module-graphnet.utilities.maths"]], "graphnet.constants": [[2, "module-graphnet.constants"]], "graphnet.data": [[3, "module-graphnet.data"]], "deepcore (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.DEEPCORE"]], "deepcore (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.DEEPCORE"]], "features (class in graphnet.data.constants)": [[4, "graphnet.data.constants.FEATURES"]], "icecube86 (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.ICECUBE86"]], "icecube86 (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.ICECUBE86"]], "kaggle (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.KAGGLE"]], "kaggle (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.KAGGLE"]], "prometheus (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.PROMETHEUS"]], "prometheus (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.PROMETHEUS"]], "truth (class in graphnet.data.constants)": [[4, "graphnet.data.constants.TRUTH"]], "upgrade (graphnet.data.constants.features attribute)": [[4, "graphnet.data.constants.FEATURES.UPGRADE"]], "upgrade (graphnet.data.constants.truth attribute)": [[4, "graphnet.data.constants.TRUTH.UPGRADE"]], "graphnet.data.constants": [[4, "module-graphnet.data.constants"]], "dataconverter (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.DataConverter"]], "fileset (class in graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.FileSet"]], "cache_output_files() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.cache_output_files"]], "execute() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.execute"]], "file_suffix (graphnet.data.dataconverter.dataconverter property)": [[5, "graphnet.data.dataconverter.DataConverter.file_suffix"]], "gcd_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.gcd_file"]], "get_map_function() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.get_map_function"]], "graphnet.data.dataconverter": [[5, "module-graphnet.data.dataconverter"]], "i3_file (graphnet.data.dataconverter.fileset attribute)": [[5, "graphnet.data.dataconverter.FileSet.i3_file"]], "init_global_index() (in module graphnet.data.dataconverter)": [[5, "graphnet.data.dataconverter.init_global_index"]], "merge_files() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.merge_files"]], "save_data() (graphnet.data.dataconverter.dataconverter method)": [[5, "graphnet.data.dataconverter.DataConverter.save_data"]], "dataloader (class in graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.DataLoader"]], "collate_fn() (in module graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.collate_fn"]], "do_shuffle() (in module graphnet.data.dataloader)": [[6, "graphnet.data.dataloader.do_shuffle"]], "from_dataset_config() (graphnet.data.dataloader.dataloader class method)": [[6, "graphnet.data.dataloader.DataLoader.from_dataset_config"]], "graphnet.data.dataloader": [[6, "module-graphnet.data.dataloader"]], "graphnet.data.dataset": [[7, "module-graphnet.data.dataset"]], "columnmissingexception": [[8, "graphnet.data.dataset.dataset.ColumnMissingException"]], "dataset (class in graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.Dataset"]], "ensembledataset (class in graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.EnsembleDataset"]], "add_label() (graphnet.data.dataset.dataset.dataset method)": [[8, "graphnet.data.dataset.dataset.Dataset.add_label"]], "concatenate() (graphnet.data.dataset.dataset.dataset class method)": [[8, "graphnet.data.dataset.dataset.Dataset.concatenate"]], "from_config() (graphnet.data.dataset.dataset.dataset class method)": [[8, "graphnet.data.dataset.dataset.Dataset.from_config"]], "graphnet.data.dataset.dataset": [[8, "module-graphnet.data.dataset.dataset"]], "load_module() (in module graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.load_module"]], "parse_graph_definition() (in module graphnet.data.dataset.dataset)": [[8, "graphnet.data.dataset.dataset.parse_graph_definition"]], "path (graphnet.data.dataset.dataset.dataset property)": [[8, "graphnet.data.dataset.dataset.Dataset.path"]], "query_table() (graphnet.data.dataset.dataset.dataset method)": [[8, "graphnet.data.dataset.dataset.Dataset.query_table"]], "truth_table (graphnet.data.dataset.dataset.dataset property)": [[8, "graphnet.data.dataset.dataset.Dataset.truth_table"]], "graphnet.data.dataset.parquet": [[9, "module-graphnet.data.dataset.parquet"]], "parquetdataset (class in graphnet.data.dataset.parquet.parquet_dataset)": [[10, "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset"]], "graphnet.data.dataset.parquet.parquet_dataset": [[10, "module-graphnet.data.dataset.parquet.parquet_dataset"]], "query_table() (graphnet.data.dataset.parquet.parquet_dataset.parquetdataset method)": [[10, "graphnet.data.dataset.parquet.parquet_dataset.ParquetDataset.query_table"]], "graphnet.data.dataset.sqlite": [[11, "module-graphnet.data.dataset.sqlite"]], "sqlitedataset (class in graphnet.data.dataset.sqlite.sqlite_dataset)": [[12, "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset"]], "graphnet.data.dataset.sqlite.sqlite_dataset": [[12, "module-graphnet.data.dataset.sqlite.sqlite_dataset"]], "query_table() (graphnet.data.dataset.sqlite.sqlite_dataset.sqlitedataset method)": [[12, "graphnet.data.dataset.sqlite.sqlite_dataset.SQLiteDataset.query_table"]], "sqlitedatasetperturbed (class in graphnet.data.dataset.sqlite.sqlite_dataset_perturbed)": [[13, "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.SQLiteDatasetPerturbed"]], "graphnet.data.dataset.sqlite.sqlite_dataset_perturbed": [[13, "module-graphnet.data.dataset.sqlite.sqlite_dataset_perturbed"]], "graphnet.data.extractors": [[14, "module-graphnet.data.extractors"]], "i3extractor (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor"]], "i3extractorcollection (class in graphnet.data.extractors.i3extractor)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection"]], "graphnet.data.extractors.i3extractor": [[15, "module-graphnet.data.extractors.i3extractor"]], "name (graphnet.data.extractors.i3extractor.i3extractor property)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.name"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractor method)": [[15, "graphnet.data.extractors.i3extractor.I3Extractor.set_files"]], "set_files() (graphnet.data.extractors.i3extractor.i3extractorcollection method)": [[15, "graphnet.data.extractors.i3extractor.I3ExtractorCollection.set_files"]], "i3featureextractor (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractor"]], "i3featureextractoricecube86 (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCube86"]], "i3featureextractoricecubedeepcore (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeDeepCore"]], "i3featureextractoricecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3FeatureExtractorIceCubeUpgrade"]], "i3pulsenoisetruthflagicecubeupgrade (class in graphnet.data.extractors.i3featureextractor)": [[16, "graphnet.data.extractors.i3featureextractor.I3PulseNoiseTruthFlagIceCubeUpgrade"]], "graphnet.data.extractors.i3featureextractor": [[16, "module-graphnet.data.extractors.i3featureextractor"]], "i3genericextractor (class in graphnet.data.extractors.i3genericextractor)": [[17, "graphnet.data.extractors.i3genericextractor.I3GenericExtractor"]], "graphnet.data.extractors.i3genericextractor": [[17, "module-graphnet.data.extractors.i3genericextractor"]], "i3galacticplanehybridrecoextractor (class in graphnet.data.extractors.i3hybridrecoextractor)": [[18, "graphnet.data.extractors.i3hybridrecoextractor.I3GalacticPlaneHybridRecoExtractor"]], "graphnet.data.extractors.i3hybridrecoextractor": [[18, "module-graphnet.data.extractors.i3hybridrecoextractor"]], "i3ntmuonlabelextractor (class in graphnet.data.extractors.i3ntmuonlabelsextractor)": [[19, "graphnet.data.extractors.i3ntmuonlabelsextractor.I3NTMuonLabelExtractor"]], "graphnet.data.extractors.i3ntmuonlabelsextractor": [[19, "module-graphnet.data.extractors.i3ntmuonlabelsextractor"]], "i3particleextractor (class in graphnet.data.extractors.i3particleextractor)": [[20, "graphnet.data.extractors.i3particleextractor.I3ParticleExtractor"]], "graphnet.data.extractors.i3particleextractor": [[20, "module-graphnet.data.extractors.i3particleextractor"]], "i3pisaextractor (class in graphnet.data.extractors.i3pisaextractor)": [[21, "graphnet.data.extractors.i3pisaextractor.I3PISAExtractor"]], "graphnet.data.extractors.i3pisaextractor": [[21, "module-graphnet.data.extractors.i3pisaextractor"]], "i3quesoextractor (class in graphnet.data.extractors.i3quesoextractor)": [[22, "graphnet.data.extractors.i3quesoextractor.I3QUESOExtractor"]], "graphnet.data.extractors.i3quesoextractor": [[22, "module-graphnet.data.extractors.i3quesoextractor"]], "i3retroextractor (class in graphnet.data.extractors.i3retroextractor)": [[23, "graphnet.data.extractors.i3retroextractor.I3RetroExtractor"]], "graphnet.data.extractors.i3retroextractor": [[23, "module-graphnet.data.extractors.i3retroextractor"]], "i3splinempeicextractor (class in graphnet.data.extractors.i3splinempeextractor)": [[24, "graphnet.data.extractors.i3splinempeextractor.I3SplineMPEICExtractor"]], "graphnet.data.extractors.i3splinempeextractor": [[24, "module-graphnet.data.extractors.i3splinempeextractor"]], "i3truthextractor (class in graphnet.data.extractors.i3truthextractor)": [[25, "graphnet.data.extractors.i3truthextractor.I3TruthExtractor"]], "graphnet.data.extractors.i3truthextractor": [[25, "module-graphnet.data.extractors.i3truthextractor"]], "i3tumextractor (class in graphnet.data.extractors.i3tumextractor)": [[26, "graphnet.data.extractors.i3tumextractor.I3TUMExtractor"]], "graphnet.data.extractors.i3tumextractor": [[26, "module-graphnet.data.extractors.i3tumextractor"]], "graphnet.data.extractors.utilities": [[27, "module-graphnet.data.extractors.utilities"]], "flatten_nested_dictionary() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.flatten_nested_dictionary"]], "graphnet.data.extractors.utilities.collections": [[28, "module-graphnet.data.extractors.utilities.collections"]], "serialise() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.serialise"]], "transpose_list_of_dicts() (in module graphnet.data.extractors.utilities.collections)": [[28, "graphnet.data.extractors.utilities.collections.transpose_list_of_dicts"]], "frame_is_montecarlo() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_montecarlo"]], "frame_is_noise() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.frame_is_noise"]], "get_om_keys_and_pulseseries() (in module graphnet.data.extractors.utilities.frames)": [[29, "graphnet.data.extractors.utilities.frames.get_om_keys_and_pulseseries"]], "graphnet.data.extractors.utilities.frames": [[29, "module-graphnet.data.extractors.utilities.frames"]], "break_cyclic_recursion() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.break_cyclic_recursion"]], "cast_object_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_object_to_pure_python"]], "cast_pulse_series_to_pure_python() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.cast_pulse_series_to_pure_python"]], "get_member_variables() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.get_member_variables"]], "graphnet.data.extractors.utilities.types": [[30, "module-graphnet.data.extractors.utilities.types"]], "is_boost_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_class"]], "is_boost_enum() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_boost_enum"]], "is_icecube_class() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_icecube_class"]], "is_method() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_method"]], "is_type() (in module graphnet.data.extractors.utilities.types)": [[30, "graphnet.data.extractors.utilities.types.is_type"]], "graphnet.data.parquet": [[31, "module-graphnet.data.parquet"]], "parquetdataconverter (class in graphnet.data.parquet.parquet_dataconverter)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter"]], "file_suffix (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter attribute)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.file_suffix"]], "graphnet.data.parquet.parquet_dataconverter": [[32, "module-graphnet.data.parquet.parquet_dataconverter"]], "merge_files() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.merge_files"]], "save_data() (graphnet.data.parquet.parquet_dataconverter.parquetdataconverter method)": [[32, "graphnet.data.parquet.parquet_dataconverter.ParquetDataConverter.save_data"]], "insqlitepipeline (class in graphnet.data.pipeline)": [[33, "graphnet.data.pipeline.InSQLitePipeline"]], "graphnet.data.pipeline": [[33, "module-graphnet.data.pipeline"]], "graphnet.data.sqlite": [[34, "module-graphnet.data.sqlite"]], "sqlitedataconverter (class in graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter"]], "any_pulsemap_is_non_empty() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.any_pulsemap_is_non_empty"]], "construct_dataframe() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.construct_dataframe"]], "file_suffix (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter attribute)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.file_suffix"]], "graphnet.data.sqlite.sqlite_dataconverter": [[35, "module-graphnet.data.sqlite.sqlite_dataconverter"]], "is_mc_tree() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_mc_tree"]], "is_pulse_map() (in module graphnet.data.sqlite.sqlite_dataconverter)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.is_pulse_map"]], "merge_files() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.merge_files"]], "save_data() (graphnet.data.sqlite.sqlite_dataconverter.sqlitedataconverter method)": [[35, "graphnet.data.sqlite.sqlite_dataconverter.SQLiteDataConverter.save_data"]], "attach_index() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.attach_index"]], "create_table() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table"]], "create_table_and_save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.create_table_and_save_to_sql"]], "database_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_exists"]], "database_table_exists() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.database_table_exists"]], "graphnet.data.sqlite.sqlite_utilities": [[36, "module-graphnet.data.sqlite.sqlite_utilities"]], "run_sql_code() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.run_sql_code"]], "save_to_sql() (in module graphnet.data.sqlite.sqlite_utilities)": [[36, "graphnet.data.sqlite.sqlite_utilities.save_to_sql"]], "graphnet.data.utilities": [[37, "module-graphnet.data.utilities"]], "parquettosqliteconverter (class in graphnet.data.utilities.parquet_to_sqlite)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter"]], "graphnet.data.utilities.parquet_to_sqlite": [[38, "module-graphnet.data.utilities.parquet_to_sqlite"]], "run() (graphnet.data.utilities.parquet_to_sqlite.parquettosqliteconverter method)": [[38, "graphnet.data.utilities.parquet_to_sqlite.ParquetToSQLiteConverter.run"]], "graphnet.data.utilities.random": [[39, "module-graphnet.data.utilities.random"]], "pairwise_shuffle() (in module graphnet.data.utilities.random)": [[39, "graphnet.data.utilities.random.pairwise_shuffle"]], "stringselectionresolver (class in graphnet.data.utilities.string_selection_resolver)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver"]], "graphnet.data.utilities.string_selection_resolver": [[40, "module-graphnet.data.utilities.string_selection_resolver"]], "resolve() (graphnet.data.utilities.string_selection_resolver.stringselectionresolver method)": [[40, "graphnet.data.utilities.string_selection_resolver.StringSelectionResolver.resolve"]], "graphnet.deployment": [[41, "module-graphnet.deployment"]], "graphneti3module (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.GraphNeTI3Module"]], "i3inferencemodule (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.I3InferenceModule"]], "i3pulsecleanermodule (class in graphnet.deployment.i3modules.graphnet_module)": [[44, "graphnet.deployment.i3modules.graphnet_module.I3PulseCleanerModule"]], "graphnet.deployment.i3modules.graphnet_module": [[44, "module-graphnet.deployment.i3modules.graphnet_module"]], "graphnet.models": [[45, "module-graphnet.models"]], "attributecoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.AttributeCoarsening"]], "coarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.Coarsening"]], "customdomcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.CustomDOMCoarsening"]], "domandtimewindowcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.DOMAndTimeWindowCoarsening"]], "domcoarsening (class in graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.DOMCoarsening"]], "forward() (graphnet.models.coarsening.coarsening method)": [[46, "graphnet.models.coarsening.Coarsening.forward"]], "graphnet.models.coarsening": [[46, "module-graphnet.models.coarsening"]], "reduce_options (graphnet.models.coarsening.coarsening attribute)": [[46, "graphnet.models.coarsening.Coarsening.reduce_options"]], "unbatch_edge_index() (in module graphnet.models.coarsening)": [[46, "graphnet.models.coarsening.unbatch_edge_index"]], "graphnet.models.components": [[47, "module-graphnet.models.components"]], "dynedgeconv (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.DynEdgeConv"]], "dyntrans (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.DynTrans"]], "edgeconvtito (class in graphnet.models.components.layers)": [[48, "graphnet.models.components.layers.EdgeConvTito"]], "forward() (graphnet.models.components.layers.dynedgeconv method)": [[48, "graphnet.models.components.layers.DynEdgeConv.forward"]], "forward() (graphnet.models.components.layers.dyntrans method)": [[48, "graphnet.models.components.layers.DynTrans.forward"]], "forward() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.forward"]], "graphnet.models.components.layers": [[48, "module-graphnet.models.components.layers"]], "message() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.message"]], "reset_parameters() (graphnet.models.components.layers.edgeconvtito method)": [[48, "graphnet.models.components.layers.EdgeConvTito.reset_parameters"]], "graphnet.models.components.pool": [[49, "module-graphnet.models.components.pool"]], "group_by() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_by"]], "group_pulses_to_dom() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_pulses_to_dom"]], "group_pulses_to_pmt() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.group_pulses_to_pmt"]], "min_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.min_pool"]], "min_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.min_pool_x"]], "std_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.std_pool"]], "std_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.std_pool_x"]], "sum_pool() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool"]], "sum_pool_and_distribute() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool_and_distribute"]], "sum_pool_x() (in module graphnet.models.components.pool)": [[49, "graphnet.models.components.pool.sum_pool_x"]], "graphnet.models.detector": [[50, "module-graphnet.models.detector"]], "detector (class in graphnet.models.detector.detector)": [[51, "graphnet.models.detector.detector.Detector"]], "feature_map() (graphnet.models.detector.detector.detector method)": [[51, "graphnet.models.detector.detector.Detector.feature_map"]], "forward() (graphnet.models.detector.detector.detector method)": [[51, "graphnet.models.detector.detector.Detector.forward"]], "graphnet.models.detector.detector": [[51, "module-graphnet.models.detector.detector"]], "icecube86 (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCube86"]], "icecubedeepcore (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeDeepCore"]], "icecubekaggle (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeKaggle"]], "icecubeupgrade (class in graphnet.models.detector.icecube)": [[52, "graphnet.models.detector.icecube.IceCubeUpgrade"]], "feature_map() (graphnet.models.detector.icecube.icecube86 method)": [[52, "graphnet.models.detector.icecube.IceCube86.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubedeepcore method)": [[52, "graphnet.models.detector.icecube.IceCubeDeepCore.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubekaggle method)": [[52, "graphnet.models.detector.icecube.IceCubeKaggle.feature_map"]], "feature_map() (graphnet.models.detector.icecube.icecubeupgrade method)": [[52, "graphnet.models.detector.icecube.IceCubeUpgrade.feature_map"]], "graphnet.models.detector.icecube": [[52, "module-graphnet.models.detector.icecube"]], "prometheus (class in graphnet.models.detector.prometheus)": [[53, "graphnet.models.detector.prometheus.Prometheus"]], "feature_map() (graphnet.models.detector.prometheus.prometheus method)": [[53, "graphnet.models.detector.prometheus.Prometheus.feature_map"]], "graphnet.models.detector.prometheus": [[53, "module-graphnet.models.detector.prometheus"]], "graphnet.models.gnn": [[54, "module-graphnet.models.gnn"]], "convnet (class in graphnet.models.gnn.convnet)": [[55, "graphnet.models.gnn.convnet.ConvNet"]], "forward() (graphnet.models.gnn.convnet.convnet method)": [[55, "graphnet.models.gnn.convnet.ConvNet.forward"]], "graphnet.models.gnn.convnet": [[55, "module-graphnet.models.gnn.convnet"]], "dynedge (class in graphnet.models.gnn.dynedge)": [[56, "graphnet.models.gnn.dynedge.DynEdge"]], "forward() (graphnet.models.gnn.dynedge.dynedge method)": [[56, "graphnet.models.gnn.dynedge.DynEdge.forward"]], "graphnet.models.gnn.dynedge": [[56, "module-graphnet.models.gnn.dynedge"]], "dynedgejinst (class in graphnet.models.gnn.dynedge_jinst)": [[57, "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST"]], "forward() (graphnet.models.gnn.dynedge_jinst.dynedgejinst method)": [[57, "graphnet.models.gnn.dynedge_jinst.DynEdgeJINST.forward"]], "graphnet.models.gnn.dynedge_jinst": [[57, "module-graphnet.models.gnn.dynedge_jinst"]], "dynedgetito (class in graphnet.models.gnn.dynedge_kaggle_tito)": [[58, "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO"]], "forward() (graphnet.models.gnn.dynedge_kaggle_tito.dynedgetito method)": [[58, "graphnet.models.gnn.dynedge_kaggle_tito.DynEdgeTITO.forward"]], "graphnet.models.gnn.dynedge_kaggle_tito": [[58, "module-graphnet.models.gnn.dynedge_kaggle_tito"]], "gnn (class in graphnet.models.gnn.gnn)": [[59, "graphnet.models.gnn.gnn.GNN"]], "forward() (graphnet.models.gnn.gnn.gnn method)": [[59, "graphnet.models.gnn.gnn.GNN.forward"]], "graphnet.models.gnn.gnn": [[59, "module-graphnet.models.gnn.gnn"]], "nb_inputs (graphnet.models.gnn.gnn.gnn property)": [[59, "graphnet.models.gnn.gnn.GNN.nb_inputs"]], "nb_outputs (graphnet.models.gnn.gnn.gnn property)": [[59, "graphnet.models.gnn.gnn.GNN.nb_outputs"]], "graphnet.models.graphs": [[60, "module-graphnet.models.graphs"]], "graphnet.models.graphs.edges": [[61, "module-graphnet.models.graphs.edges"]], "edgedefinition (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.EdgeDefinition"]], "euclideanedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.EuclideanEdges"]], "knnedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.KNNEdges"]], "radialedges (class in graphnet.models.graphs.edges.edges)": [[62, "graphnet.models.graphs.edges.edges.RadialEdges"]], "forward() (graphnet.models.graphs.edges.edges.edgedefinition method)": [[62, "graphnet.models.graphs.edges.edges.EdgeDefinition.forward"]], "graphnet.models.graphs.edges.edges": [[62, "module-graphnet.models.graphs.edges.edges"]], "graphdefinition (class in graphnet.models.graphs.graph_definition)": [[63, "graphnet.models.graphs.graph_definition.GraphDefinition"]], "forward() (graphnet.models.graphs.graph_definition.graphdefinition method)": [[63, "graphnet.models.graphs.graph_definition.GraphDefinition.forward"]], "graphnet.models.graphs.graph_definition": [[63, "module-graphnet.models.graphs.graph_definition"]], "knngraph (class in graphnet.models.graphs.graphs)": [[64, "graphnet.models.graphs.graphs.KNNGraph"]], "graphnet.models.graphs.graphs": [[64, "module-graphnet.models.graphs.graphs"]], "graphnet.models.graphs.nodes": [[65, "module-graphnet.models.graphs.nodes"]], "nodedefinition (class in graphnet.models.graphs.nodes.nodes)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition"]], "nodesaspulses (class in graphnet.models.graphs.nodes.nodes)": [[66, "graphnet.models.graphs.nodes.nodes.NodesAsPulses"]], "forward() (graphnet.models.graphs.nodes.nodes.nodedefinition method)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.forward"]], "graphnet.models.graphs.nodes.nodes": [[66, "module-graphnet.models.graphs.nodes.nodes"]], "nb_outputs (graphnet.models.graphs.nodes.nodes.nodedefinition property)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.nb_outputs"]], "set_number_of_inputs() (graphnet.models.graphs.nodes.nodes.nodedefinition method)": [[66, "graphnet.models.graphs.nodes.nodes.NodeDefinition.set_number_of_inputs"]], "model (class in graphnet.models.model)": [[67, "graphnet.models.model.Model"]], "fit() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.fit"]], "forward() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.forward"]], "from_config() (graphnet.models.model.model class method)": [[67, "graphnet.models.model.Model.from_config"]], "graphnet.models.model": [[67, "module-graphnet.models.model"]], "load() (graphnet.models.model.model class method)": [[67, "graphnet.models.model.Model.load"]], "load_state_dict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.load_state_dict"]], "predict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.predict"]], "predict_as_dataframe() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.predict_as_dataframe"]], "save() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.save"]], "save_state_dict() (graphnet.models.model.model method)": [[67, "graphnet.models.model.Model.save_state_dict"]], "standardmodel (class in graphnet.models.standard_model)": [[68, "graphnet.models.standard_model.StandardModel"]], "compute_loss() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.compute_loss"]], "configure_optimizers() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.configure_optimizers"]], "forward() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.forward"]], "graphnet.models.standard_model": [[68, "module-graphnet.models.standard_model"]], "inference() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.inference"]], "predict() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.predict"]], "predict_as_dataframe() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.predict_as_dataframe"]], "prediction_labels (graphnet.models.standard_model.standardmodel property)": [[68, "graphnet.models.standard_model.StandardModel.prediction_labels"]], "shared_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.shared_step"]], "target_labels (graphnet.models.standard_model.standardmodel property)": [[68, "graphnet.models.standard_model.StandardModel.target_labels"]], "train() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.train"]], "training_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.training_step"]], "validation_step() (graphnet.models.standard_model.standardmodel method)": [[68, "graphnet.models.standard_model.StandardModel.validation_step"]], "graphnet.models.task": [[69, "module-graphnet.models.task"]], "binaryclassificationtask (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.BinaryClassificationTask"]], "binaryclassificationtasklogits (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits"]], "multiclassclassificationtask (class in graphnet.models.task.classification)": [[70, "graphnet.models.task.classification.MulticlassClassificationTask"]], "default_prediction_labels (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.default_prediction_labels"]], "default_target_labels (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.default_target_labels"]], "default_target_labels (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.default_target_labels"]], "graphnet.models.task.classification": [[70, "module-graphnet.models.task.classification"]], "nb_inputs (graphnet.models.task.classification.binaryclassificationtask attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTask.nb_inputs"]], "nb_inputs (graphnet.models.task.classification.binaryclassificationtasklogits attribute)": [[70, "graphnet.models.task.classification.BinaryClassificationTaskLogits.nb_inputs"]], "azimuthreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction"]], "azimuthreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa"]], "directionreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa"]], "energyreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction"]], "energyreconstructionwithpower (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower"]], "energyreconstructionwithuncertainty (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty"]], "inelasticityreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction"]], "positionreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction"]], "timereconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction"]], "vertexreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction"]], "zenithreconstruction (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction"]], "zenithreconstructionwithkappa (class in graphnet.models.task.reconstruction)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa"]], "default_prediction_labels (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_prediction_labels"]], "default_target_labels (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.default_target_labels"]], "default_target_labels (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.default_target_labels"]], "graphnet.models.task.reconstruction": [[71, "module-graphnet.models.task.reconstruction"]], "nb_inputs (graphnet.models.task.reconstruction.azimuthreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.azimuthreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.AzimuthReconstructionWithKappa.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.directionreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.DirectionReconstructionWithKappa.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstructionwithpower attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithPower.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.energyreconstructionwithuncertainty attribute)": [[71, "graphnet.models.task.reconstruction.EnergyReconstructionWithUncertainty.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.inelasticityreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.InelasticityReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.positionreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.PositionReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.timereconstruction attribute)": [[71, "graphnet.models.task.reconstruction.TimeReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.vertexreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.VertexReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.zenithreconstruction attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstruction.nb_inputs"]], "nb_inputs (graphnet.models.task.reconstruction.zenithreconstructionwithkappa attribute)": [[71, "graphnet.models.task.reconstruction.ZenithReconstructionWithKappa.nb_inputs"]], "identitytask (class in graphnet.models.task.task)": [[72, "graphnet.models.task.task.IdentityTask"]], "task (class in graphnet.models.task.task)": [[72, "graphnet.models.task.task.Task"]], "compute_loss() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.compute_loss"]], "default_prediction_labels (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.default_prediction_labels"]], "default_prediction_labels (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.default_prediction_labels"]], "default_target_labels (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.default_target_labels"]], "default_target_labels (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.default_target_labels"]], "forward() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.forward"]], "graphnet.models.task.task": [[72, "module-graphnet.models.task.task"]], "inference() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.inference"]], "nb_inputs (graphnet.models.task.task.identitytask property)": [[72, "graphnet.models.task.task.IdentityTask.nb_inputs"]], "nb_inputs (graphnet.models.task.task.task property)": [[72, "graphnet.models.task.task.Task.nb_inputs"]], "train_eval() (graphnet.models.task.task.task method)": [[72, "graphnet.models.task.task.Task.train_eval"]], "calculate_distance_matrix() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.calculate_distance_matrix"]], "calculate_xyzt_homophily() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.calculate_xyzt_homophily"]], "graphnet.models.utils": [[73, "module-graphnet.models.utils"]], "knn_graph_batch() (in module graphnet.models.utils)": [[73, "graphnet.models.utils.knn_graph_batch"]], "graphnet.pisa": [[74, "module-graphnet.pisa"]], "contourfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.ContourFitter"]], "weightfitter (class in graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.WeightFitter"]], "config_updater() (in module graphnet.pisa.fitting)": [[75, "graphnet.pisa.fitting.config_updater"]], "fit_1d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_1d_contour"]], "fit_2d_contour() (graphnet.pisa.fitting.contourfitter method)": [[75, "graphnet.pisa.fitting.ContourFitter.fit_2d_contour"]], "fit_weights() (graphnet.pisa.fitting.weightfitter method)": [[75, "graphnet.pisa.fitting.WeightFitter.fit_weights"]], "graphnet.pisa.fitting": [[75, "module-graphnet.pisa.fitting"]], "graphnet.pisa.plotting": [[76, "module-graphnet.pisa.plotting"]], "plot_1d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_1D_contour"]], "plot_2d_contour() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.plot_2D_contour"]], "read_entry() (in module graphnet.pisa.plotting)": [[76, "graphnet.pisa.plotting.read_entry"]], "graphnet.training": [[77, "module-graphnet.training"]], "piecewiselinearlr (class in graphnet.training.callbacks)": [[78, "graphnet.training.callbacks.PiecewiseLinearLR"]], "progressbar (class in graphnet.training.callbacks)": [[78, "graphnet.training.callbacks.ProgressBar"]], "get_lr() (graphnet.training.callbacks.piecewiselinearlr method)": [[78, "graphnet.training.callbacks.PiecewiseLinearLR.get_lr"]], "get_metrics() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.get_metrics"]], "graphnet.training.callbacks": [[78, "module-graphnet.training.callbacks"]], "init_predict_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_predict_tqdm"]], "init_test_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_test_tqdm"]], "init_train_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_train_tqdm"]], "init_validation_tqdm() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.init_validation_tqdm"]], "on_train_epoch_end() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.on_train_epoch_end"]], "on_train_epoch_start() (graphnet.training.callbacks.progressbar method)": [[78, "graphnet.training.callbacks.ProgressBar.on_train_epoch_start"]], "direction (class in graphnet.training.labels)": [[79, "graphnet.training.labels.Direction"]], "label (class in graphnet.training.labels)": [[79, "graphnet.training.labels.Label"]], "graphnet.training.labels": [[79, "module-graphnet.training.labels"]], "key (graphnet.training.labels.label property)": [[79, "graphnet.training.labels.Label.key"]], "binarycrossentropyloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.BinaryCrossEntropyLoss"]], "crossentropyloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.CrossEntropyLoss"]], "euclideandistanceloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.EuclideanDistanceLoss"]], "logcmk (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LogCMK"]], "logcoshloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LogCoshLoss"]], "lossfunction (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.LossFunction"]], "mseloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.MSELoss"]], "rmseloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.RMSELoss"]], "vonmisesfisher2dloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisher2DLoss"]], "vonmisesfisher3dloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisher3DLoss"]], "vonmisesfisherloss (class in graphnet.training.loss_functions)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss"]], "backward() (graphnet.training.loss_functions.logcmk static method)": [[80, "graphnet.training.loss_functions.LogCMK.backward"]], "forward() (graphnet.training.loss_functions.logcmk static method)": [[80, "graphnet.training.loss_functions.LogCMK.forward"]], "forward() (graphnet.training.loss_functions.lossfunction method)": [[80, "graphnet.training.loss_functions.LossFunction.forward"]], "graphnet.training.loss_functions": [[80, "module-graphnet.training.loss_functions"]], "log_cmk() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk"]], "log_cmk_approx() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_approx"]], "log_cmk_exact() (graphnet.training.loss_functions.vonmisesfisherloss class method)": [[80, "graphnet.training.loss_functions.VonMisesFisherLoss.log_cmk_exact"]], "collate_fn() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.collate_fn"]], "get_predictions() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.get_predictions"]], "graphnet.training.utils": [[81, "module-graphnet.training.utils"]], "make_dataloader() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.make_dataloader"]], "make_train_validation_dataloader() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.make_train_validation_dataloader"]], "save_results() (in module graphnet.training.utils)": [[81, "graphnet.training.utils.save_results"]], "bjoernlow (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.BjoernLow"]], "uniform (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.Uniform"]], "weightfitter (class in graphnet.training.weight_fitting)": [[82, "graphnet.training.weight_fitting.WeightFitter"]], "fit() (graphnet.training.weight_fitting.weightfitter method)": [[82, "graphnet.training.weight_fitting.WeightFitter.fit"]], "graphnet.training.weight_fitting": [[82, "module-graphnet.training.weight_fitting"]], "graphnet.utilities": [[83, "module-graphnet.utilities"]], "argumentparser (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.ArgumentParser"]], "options (class in graphnet.utilities.argparse)": [[84, "graphnet.utilities.argparse.Options"]], "contains() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.contains"]], "graphnet.utilities.argparse": [[84, "module-graphnet.utilities.argparse"]], "pop_default() (graphnet.utilities.argparse.options method)": [[84, "graphnet.utilities.argparse.Options.pop_default"]], "standard_arguments (graphnet.utilities.argparse.argumentparser attribute)": [[84, "graphnet.utilities.argparse.ArgumentParser.standard_arguments"]], "with_standard_arguments() (graphnet.utilities.argparse.argumentparser method)": [[84, "graphnet.utilities.argparse.ArgumentParser.with_standard_arguments"]], "graphnet.utilities.config": [[85, "module-graphnet.utilities.config"]], "baseconfig (class in graphnet.utilities.config.base_config)": [[86, "graphnet.utilities.config.base_config.BaseConfig"]], "as_dict() (graphnet.utilities.config.base_config.baseconfig method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.as_dict"]], "dump() (graphnet.utilities.config.base_config.baseconfig method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.dump"]], "get_all_argument_values() (in module graphnet.utilities.config.base_config)": [[86, "graphnet.utilities.config.base_config.get_all_argument_values"]], "graphnet.utilities.config.base_config": [[86, "module-graphnet.utilities.config.base_config"]], "load() (graphnet.utilities.config.base_config.baseconfig class method)": [[86, "graphnet.utilities.config.base_config.BaseConfig.load"]], "model_config (graphnet.utilities.config.base_config.baseconfig attribute)": [[86, "graphnet.utilities.config.base_config.BaseConfig.model_config"]], "model_fields (graphnet.utilities.config.base_config.baseconfig attribute)": [[86, "graphnet.utilities.config.base_config.BaseConfig.model_fields"]], "configurable (class in graphnet.utilities.config.configurable)": [[87, "graphnet.utilities.config.configurable.Configurable"]], "config (graphnet.utilities.config.configurable.configurable property)": [[87, "graphnet.utilities.config.configurable.Configurable.config"]], "from_config() (graphnet.utilities.config.configurable.configurable class method)": [[87, "graphnet.utilities.config.configurable.Configurable.from_config"]], "graphnet.utilities.config.configurable": [[87, "module-graphnet.utilities.config.configurable"]], "save_config() (graphnet.utilities.config.configurable.configurable method)": [[87, "graphnet.utilities.config.configurable.Configurable.save_config"]], "datasetconfig (class in graphnet.utilities.config.dataset_config)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig"]], "as_dict() (graphnet.utilities.config.dataset_config.datasetconfig method)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.as_dict"]], "features (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.features"]], "graph_definition (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.graph_definition"]], "graphnet.utilities.config.dataset_config": [[88, "module-graphnet.utilities.config.dataset_config"]], "index_column (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.index_column"]], "loss_weight_column (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_column"]], "loss_weight_default_value (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_default_value"]], "loss_weight_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.loss_weight_table"]], "model_config (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.model_config"]], "model_fields (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.model_fields"]], "node_truth (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.node_truth"]], "node_truth_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.node_truth_table"]], "path (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.path"]], "pulsemaps (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.pulsemaps"]], "save_dataset_config() (in module graphnet.utilities.config.dataset_config)": [[88, "graphnet.utilities.config.dataset_config.save_dataset_config"]], "seed (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.seed"]], "selection (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.selection"]], "string_selection (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.string_selection"]], "truth (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.truth"]], "truth_table (graphnet.utilities.config.dataset_config.datasetconfig attribute)": [[88, "graphnet.utilities.config.dataset_config.DatasetConfig.truth_table"]], "modelconfig (class in graphnet.utilities.config.model_config)": [[89, "graphnet.utilities.config.model_config.ModelConfig"]], "arguments (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.arguments"]], "as_dict() (graphnet.utilities.config.model_config.modelconfig method)": [[89, "graphnet.utilities.config.model_config.ModelConfig.as_dict"]], "class_name (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.class_name"]], "graphnet.utilities.config.model_config": [[89, "module-graphnet.utilities.config.model_config"]], "model_config (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.model_config"]], "model_fields (graphnet.utilities.config.model_config.modelconfig attribute)": [[89, "graphnet.utilities.config.model_config.ModelConfig.model_fields"]], "save_model_config() (in module graphnet.utilities.config.model_config)": [[89, "graphnet.utilities.config.model_config.save_model_config"]], "get_all_grapnet_classes() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.get_all_grapnet_classes"]], "get_graphnet_classes() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.get_graphnet_classes"]], "graphnet.utilities.config.parsing": [[90, "module-graphnet.utilities.config.parsing"]], "is_graphnet_class() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.is_graphnet_class"]], "is_graphnet_module() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.is_graphnet_module"]], "list_all_submodules() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.list_all_submodules"]], "traverse_and_apply() (in module graphnet.utilities.config.parsing)": [[90, "graphnet.utilities.config.parsing.traverse_and_apply"]], "trainingconfig (class in graphnet.utilities.config.training_config)": [[91, "graphnet.utilities.config.training_config.TrainingConfig"]], "dataloader (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.dataloader"]], "early_stopping_patience (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.early_stopping_patience"]], "fit (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.fit"]], "graphnet.utilities.config.training_config": [[91, "module-graphnet.utilities.config.training_config"]], "model_config (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.model_config"]], "model_fields (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.model_fields"]], "target (graphnet.utilities.config.training_config.trainingconfig attribute)": [[91, "graphnet.utilities.config.training_config.TrainingConfig.target"]], "graphnet.utilities.decorators": [[92, "module-graphnet.utilities.decorators"]], "find_i3_files() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.find_i3_files"]], "graphnet.utilities.filesys": [[93, "module-graphnet.utilities.filesys"]], "has_extension() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.has_extension"]], "is_gcd_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_gcd_file"]], "is_i3_file() (in module graphnet.utilities.filesys)": [[93, "graphnet.utilities.filesys.is_i3_file"]], "graphnet.utilities.imports": [[94, "module-graphnet.utilities.imports"]], "has_icecube_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_icecube_package"]], "has_pisa_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_pisa_package"]], "has_torch_package() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.has_torch_package"]], "requires_icecube() (in module graphnet.utilities.imports)": [[94, "graphnet.utilities.imports.requires_icecube"]], "logger (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.Logger"]], "repeatfilter (class in graphnet.utilities.logging)": [[95, "graphnet.utilities.logging.RepeatFilter"]], "critical() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.critical"]], "debug() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.debug"]], "error() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.error"]], "file_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.file_handlers"]], "filter() (graphnet.utilities.logging.repeatfilter method)": [[95, "graphnet.utilities.logging.RepeatFilter.filter"]], "graphnet.utilities.logging": [[95, "module-graphnet.utilities.logging"]], "handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.handlers"]], "info() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.info"]], "nb_repeats_allowed (graphnet.utilities.logging.repeatfilter attribute)": [[95, "graphnet.utilities.logging.RepeatFilter.nb_repeats_allowed"]], "setlevel() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.setLevel"]], "stream_handlers (graphnet.utilities.logging.logger property)": [[95, "graphnet.utilities.logging.Logger.stream_handlers"]], "warning() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning"]], "warning_once() (graphnet.utilities.logging.logger method)": [[95, "graphnet.utilities.logging.Logger.warning_once"]], "eps_like() (in module graphnet.utilities.maths)": [[96, "graphnet.utilities.maths.eps_like"]], "graphnet.utilities.maths": [[96, "module-graphnet.utilities.maths"]]}}) \ No newline at end of file diff --git a/sitemap.xml b/sitemap.xml index 8ca584144..1332adef5 100644 --- a/sitemap.xml +++ b/sitemap.xml @@ -1 +1 @@ -https://graphnet-team.github.io/graphnetabout.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.constants.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.constants.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataconverter.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataloader.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.parquet_dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3extractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3featureextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3genericextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3hybridrecoextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3ntmuonlabelsextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3particleextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3pisaextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3quesoextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3retroextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3splinempeextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3truthextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3tumextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.collections.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.frames.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.types.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.parquet_dataconverter.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.pipeline.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_dataconverter.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.parquet_to_sqlite.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.random.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.string_selection_resolver.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.deployer.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.graphnet_module.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.coarsening.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.components.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.components.layers.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.components.pool.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.detector.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.icecube.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.prometheus.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.convnet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_jinst.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_kaggle_tito.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.gnn.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.edges.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graph_definition.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graphs.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.nodes.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.model.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.standard_model.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.classification.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.reconstruction.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.task.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.utils.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.pisa.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.pisa.fitting.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.pisa.plotting.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.callbacks.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.labels.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.loss_functions.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.utils.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.weight_fitting.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.argparse.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.base_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.configurable.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.dataset_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.model_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.parsing.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.training_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.decorators.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.filesys.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.imports.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.logging.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.maths.htmlhttps://graphnet-team.github.io/graphnetapi/modules.htmlhttps://graphnet-team.github.io/graphnetcontribute.htmlhttps://graphnet-team.github.io/graphnetindex.htmlhttps://graphnet-team.github.io/graphnetinstall.htmlhttps://graphnet-team.github.io/graphnetgenindex.htmlhttps://graphnet-team.github.io/graphnetpy-modindex.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/constants.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/dataconverter.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3extractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3featureextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3genericextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3hybridrecoextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3particleextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3pisaextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3quesoextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3retroextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3splinempeextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3truthextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3tumextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/collections.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/frames.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/types.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/parquet/parquet_dataconverter.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_dataconverter.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_utilities.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/parquet_to_sqlite.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/random.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/string_selection_resolver.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/pisa/fitting.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/pisa/plotting.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/training/weight_fitting.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/argparse.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/filesys.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/imports.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/logging.htmlhttps://graphnet-team.github.io/graphnet_modules/index.htmlhttps://graphnet-team.github.io/graphnetsearch.html \ No newline at end of file +https://graphnet-team.github.io/graphnetabout.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.constants.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.constants.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataconverter.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataloader.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.parquet.parquet_dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.dataset.sqlite.sqlite_dataset_perturbed.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3extractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3featureextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3genericextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3hybridrecoextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3ntmuonlabelsextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3particleextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3pisaextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3quesoextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3retroextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3splinempeextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3truthextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.i3tumextractor.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.collections.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.frames.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.extractors.utilities.types.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.parquet.parquet_dataconverter.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.pipeline.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_dataconverter.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.sqlite.sqlite_utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.parquet_to_sqlite.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.random.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.data.utilities.string_selection_resolver.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.deployer.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.deployment.i3modules.graphnet_module.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.coarsening.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.components.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.components.layers.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.components.pool.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.detector.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.icecube.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.detector.prometheus.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.convnet.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_jinst.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.dynedge_kaggle_tito.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.gnn.gnn.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.edges.edges.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graph_definition.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.graphs.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.graphs.nodes.nodes.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.model.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.standard_model.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.classification.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.reconstruction.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.task.task.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.models.utils.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.pisa.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.pisa.fitting.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.pisa.plotting.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.callbacks.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.labels.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.loss_functions.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.utils.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.training.weight_fitting.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.argparse.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.base_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.configurable.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.dataset_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.model_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.parsing.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.config.training_config.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.decorators.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.filesys.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.imports.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.logging.htmlhttps://graphnet-team.github.io/graphnetapi/graphnet.utilities.maths.htmlhttps://graphnet-team.github.io/graphnetapi/modules.htmlhttps://graphnet-team.github.io/graphnetcontribute.htmlhttps://graphnet-team.github.io/graphnetindex.htmlhttps://graphnet-team.github.io/graphnetinstall.htmlhttps://graphnet-team.github.io/graphnetgenindex.htmlhttps://graphnet-team.github.io/graphnetpy-modindex.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/constants.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/dataconverter.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/dataloader.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/dataset.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/parquet/parquet_dataset.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/sqlite/sqlite_dataset.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3extractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3featureextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3genericextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3hybridrecoextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3ntmuonlabelsextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3particleextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3pisaextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3quesoextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3retroextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3splinempeextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3truthextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/i3tumextractor.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/collections.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/frames.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/extractors/utilities/types.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/parquet/parquet_dataconverter.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/pipeline.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_dataconverter.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/sqlite/sqlite_utilities.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/parquet_to_sqlite.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/random.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/data/utilities/string_selection_resolver.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/deployment/i3modules/graphnet_module.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/coarsening.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/components/layers.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/components/pool.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/detector.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/icecube.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/detector/prometheus.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/convnet.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge_jinst.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/dynedge_kaggle_tito.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/gnn/gnn.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/edges/edges.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/graph_definition.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/graphs.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/graphs/nodes/nodes.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/model.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/standard_model.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/task/classification.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/task/reconstruction.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/task/task.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/models/utils.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/pisa/fitting.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/pisa/plotting.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/training/callbacks.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/training/labels.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/training/loss_functions.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/training/utils.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/training/weight_fitting.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/argparse.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/base_config.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/configurable.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/dataset_config.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/model_config.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/parsing.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/config/training_config.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/filesys.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/imports.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/logging.htmlhttps://graphnet-team.github.io/graphnet_modules/graphnet/utilities/maths.htmlhttps://graphnet-team.github.io/graphnet_modules/index.htmlhttps://graphnet-team.github.io/graphnetsearch.html \ No newline at end of file