diff --git a/_modules/graphnet/data/datamodule.html b/_modules/graphnet/data/datamodule.html new file mode 100644 index 000000000..9813c0db0 --- /dev/null +++ b/_modules/graphnet/data/datamodule.html @@ -0,0 +1,828 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + graphnet.data.datamodule — graphnet documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Skip to content +
+ +
+ + +
+ + + + +
+
+ +
+
+
+ +
+
+
+
+
+
+ + +
+
+
+ +
+
+ +

Source code for graphnet.data.datamodule

+"""Base `Dataloader` class(es) used in `graphnet`."""
+from typing import Dict, Any, Optional, List, Tuple, Union
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+from copy import deepcopy
+from sklearn.model_selection import train_test_split
+import pandas as pd
+
+from graphnet.data.dataset import (
+    Dataset,
+    EnsembleDataset,
+    SQLiteDataset,
+    ParquetDataset,
+)
+from graphnet.utilities.logging import Logger
+
+
+
+[docs] +class GraphNeTDataModule(pl.LightningDataModule, Logger): + """General Class for DataLoader Construction.""" + + def __init__( + self, + dataset_reference: Union[SQLiteDataset, ParquetDataset, Dataset], + dataset_args: Dict[str, Any], + selection: Optional[Union[List[int], List[List[int]]]] = None, + test_selection: Optional[Union[List[int], List[List[int]]]] = None, + train_dataloader_kwargs: Optional[Dict[str, Any]] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, + train_val_split: Optional[List[float]] = [0.9, 0.10], + split_seed: int = 42, + ) -> None: + """Create dataloaders from dataset. + + Args: + dataset_reference: A non-instantiated reference + to the dataset class. + dataset_args: Arguments to instantiate + graphnet.data.dataset.Dataset with. + selection: (Optional) a list of event id's used for training + and validation, Default None. + test_selection: (Optional) a list of event id's used for testing, + Default None. + train_dataloader_kwargs: Arguments for the training DataLoader, + Default None. + validation_dataloader_kwargs: Arguments for the validation + DataLoader, Default None. + test_dataloader_kwargs: Arguments for the test DataLoader, + Default None. + train_val_split (Optional): Split ratio for training and + validation sets. Default is [0.9, 0.10]. + split_seed: seed used for shuffling and splitting selections into + train/validation, Default 42. + """ + Logger.__init__(self) + self._make_sure_root_logger_is_configured() + self._dataset = dataset_reference + self._dataset_args = dataset_args + self._selection = selection + self._test_selection = test_selection + self._train_val_split = train_val_split or [0.0] + self._rng = split_seed + + self._train_dataloader_kwargs = train_dataloader_kwargs or {} + self._validation_dataloader_kwargs = validation_dataloader_kwargs or {} + self._test_dataloader_kwargs = test_dataloader_kwargs or {} + + # If multiple dataset paths are given, we should use EnsembleDataset + self._use_ensemble_dataset = isinstance( + self._dataset_args["path"], list + ) + + self.setup("fit") + +
+[docs] + def prepare_data(self) -> None: + """Prepare the dataset for training.""" + # Download method for curated datasets. Method for download is + # likely dataset-specific, so we can leave it as-is + pass
+ + +
+[docs] + def setup(self, stage: str) -> None: + """Prepare Datasets for DataLoaders. + + Args: + stage: lightning stage. Either "fit, validate, test, predict" + """ + # Sanity Checks + self._validate_dataset_class() + self._validate_dataset_args() + self._validate_dataloader_args() + + # Case-handling of selection arguments + self._resolve_selections() + + # Creation of Datasets + if ( + self._test_selection is not None + or len(self._test_dataloader_kwargs) > 0 + ): + self._test_dataset = self._create_dataset( + self._test_selection # type: ignore + ) + if stage == "fit" or stage == "validate": + if self._train_selection is not None: + self._train_dataset = self._create_dataset( + self._train_selection + ) + if self._val_selection is not None: + self._val_dataset = self._create_dataset(self._val_selection) + + return
+ + + @property + def train_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the training DataLoader. + + Returns: + DataLoader: The DataLoader configured for training. + """ + return self._create_dataloader(self._train_dataset) + + @property + def val_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the validation DataLoader. + + Returns: + DataLoader: The DataLoader configured for validation. + """ + return self._create_dataloader(self._val_dataset) + + @property + def test_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the test DataLoader. + + Returns: + DataLoader: The DataLoader configured for testing. + """ + return self._create_dataloader(self._test_dataset) + +
+[docs] + def teardown(self) -> None: # type: ignore[override] + """Perform any necessary cleanup or shutdown procedures. + + This method can be used for tasks such as closing SQLite connections + after training. Override this method as needed. + + Returns: + None + """ + if hasattr(self, "_train_dataset") and isinstance( + self._train_dataset, SQLiteDataset + ): + self._train_dataset._close_connection() + + if hasattr(self, "_val_dataset") and isinstance( + self._val_dataset, SQLiteDataset + ): + self._val_dataset._close_connection() + + if hasattr(self, "_test_dataset") and isinstance( + self._test_dataset, SQLiteDataset + ): + self._test_dataset._close_connection() + + return
+ + + def _create_dataloader( + self, dataset: Union[Dataset, EnsembleDataset] + ) -> DataLoader: + """Create a DataLoader for the given dataset. + + Args: + dataset (Union[Dataset, EnsembleDataset]): + The dataset to create a DataLoader for. + + Returns: + DataLoader: The DataLoader configured for the given dataset. + """ + if dataset == self._train_dataset: + dataloader_args = self._train_dataloader_kwargs + elif dataset == self._val_dataset: + dataloader_args = self._validation_dataloader_kwargs + elif dataset == self._test_dataset: + dataloader_args = self._test_dataloader_kwargs + else: + raise ValueError( + "Unknown dataset encountered during dataloader creation." + ) + + if dataloader_args is None: + raise AttributeError("Dataloader arguments not provided.") + + return DataLoader(dataset=dataset, **dataloader_args) + + def _validate_dataset_class(self) -> None: + """Sanity checks on the dataset reference (self._dataset). + + Checks whether the dataset is an instance of SQLiteDataset, + ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset + type is detected, or if an EnsembleDataset is used. + """ + allowed_types = (SQLiteDataset, ParquetDataset, Dataset) + if self._dataset not in allowed_types: + raise TypeError( + "dataset_reference must be an instance " + "of SQLiteDataset, ParquetDataset, or Dataset." + ) + if self._dataset is EnsembleDataset: + raise TypeError( + "EnsembleDataset is not allowed as dataset_reference." + ) + + def _validate_dataset_args(self) -> None: + """Sanity checks on the arguments for the dataset reference.""" + if isinstance(self._dataset_args["path"], list): + if self._selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of selections given as arg. + assert len(self._dataset_args["path"]) == len( + self._selection + ) + except AssertionError: + raise ValueError( + "The number of dataset paths" + f" ({len(self._dataset_args['path'])})" + " does not match the number of" + f" selections ({len(self._selection)})." + ) + + if self._test_selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of test selections. + assert len(self._dataset_args["path"]) == len( + self._test_selection + ) + except AssertionError: + raise ValueError( + "The number of dataset paths " + f" ({len(self._dataset_args['path'])}) does not match " + "the number of test selections " + f"({len(self._test_selection)}).If you'd like to test " + "on only a subset of the " + f"{len(self._dataset_args['path'])} datasets, " + "please provide empty test selections for the others." + ) + + def _validate_dataloader_args(self) -> None: + """Sanity check on `dataloader_args`.""" + if "dataset" in self._train_dataloader_kwargs: + raise ValueError( + "`train_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._validation_dataloader_kwargs: + raise ValueError( + "`validation_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._test_dataloader_kwargs: + raise ValueError( + "`test_dataloader_kwargs` must not contain `dataset`" + ) + + def _resolve_selections(self) -> None: + if self._test_selection is None: + self.warning_once( + f"{self.__class__.__name__} did not receive an" + " argument for `test_selection` and will " + "therefore not have a prediction dataloader available." + ) + if self._selection is not None: + # Split the selection into train/validation + if self._use_ensemble_dataset: + # Split every selection + self._train_selection = [] + self._val_selection = [] + for selection in self._selection: + train_selection, val_selection = self._split_selection( + selection + ) + self._train_selection.append(train_selection) + self._val_selection.append(val_selection) + + else: + # Split the only selection we got + assert isinstance(self._selection, list) + ( + self._train_selection, + self._val_selection, + ) = self._split_selection( # type: ignore + self._selection + ) + + else: # selection is None + # If not provided, we infer it by grabbing + # all event ids in the dataset. + self.info( + f"{self.__class__.__name__} did not receive an" + " for `selection`. Selection will " + "will automatically be created with a split of " + f"train: {self._train_val_split[0]} and " + f"validation: {self._train_val_split[1]}" + ) + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections() # type: ignore + + def _split_selection( + self, selection: Union[int, List[int], List[List[int]]] + ) -> Tuple[List[int], List[int]]: + """Split train selection into train/validation. + + Args: + selection: Training selection to be split + + Returns: + Training selection, Validation selection. + """ + assert isinstance(selection, (int, list)) + if isinstance(selection, int): + flat_selection = [selection] + elif isinstance(selection[0], list): + flat_selection = [ + item + for sublist in selection + for item in sublist # type: ignore + ] + else: + flat_selection = selection # type: ignore + assert isinstance(flat_selection, list) + + train_selection, val_selection = train_test_split( + flat_selection, + train_size=self._train_val_split[0], + test_size=self._train_val_split[1], + random_state=self._rng, + ) + return train_selection, val_selection + + def _infer_selections(self) -> Tuple[List[int], List[int]]: + """Automatically infer training and validation selections. + + Returns: + Training selection, Validation selection + """ + if self._use_ensemble_dataset: + # We must iterate through the dataset paths and infer a train/val + # selection for each. + self._train_selection = [] + self._val_selection = [] + for dataset_path in self._dataset_args["path"]: + ( + train_selection, + val_selection, + ) = self._infer_selections_on_single_dataset(dataset_path) + self._train_selection.append(train_selection) # type: ignore + self._val_selection.append(val_selection) # type: ignore + else: + # Infer selection on a single dataset + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections_on_single_dataset( # type: ignore + self._dataset_args["path"] + ) + + return (self._train_selection, self._val_selection) # type: ignore + + def _infer_selections_on_single_dataset( + self, dataset_path: str + ) -> Tuple[List[int], List[int]]: + """Automatically infers dataset train/val selections. + + Args: + dataset_path (str): The path to the dataset. + + Returns: + Tuple[List[int], List[int]]: Training and validation selections. + """ + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = dataset_path + tmp_dataset = self._construct_dataset(tmp_args) + + all_events = ( + tmp_dataset._get_all_indices() + ) # unshuffled list, sequential index + + # Multiple lines to avoid one large + all_events = ( + pd.DataFrame(all_events) + .sample(frac=1, replace=False, random_state=self._rng) + .values.tolist() + ) # shuffled list + + return self._split_selection(all_events) + + def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: + """Construct dataset. + + Return: + Dataset object constructed from input arguments. + """ + dataset = self._dataset(**tmp_args) # type: ignore + return dataset + + def _create_dataset( + self, selection: Union[List[int], List[List[int]], List[float]] + ) -> Union[EnsembleDataset, Dataset]: + """Instantiate `dataset_reference`. + + Args: + selection: The selected event id's. + + Returns: + A dataset, either an instance of `EnsembleDataset` or `Dataset`. + """ + if self._use_ensemble_dataset: + # Construct multiple datasets and pass to EnsembleDataset + # len(selection) == len(dataset_args['path']) + datasets = [] + for dataset_idx in range(len(selection)): + datasets.append( + self._create_single_dataset( + selection=selection[dataset_idx], # type: ignore + path=self._dataset_args["path"][dataset_idx], + ) + ) + + dataset = EnsembleDataset(datasets) + + else: + # Construct single dataset + dataset = self._create_single_dataset( + selection=selection, + path=self._dataset_args["path"], # type:ignore + ) + return dataset + + def _create_single_dataset( + self, + selection: Union[List[int], List[List[int]], List[float]], + path: str, + ) -> Dataset: + """Instantiate a single `Dataset`. + + Args: + selection: A selection for a single dataset. + path: Path to a single dataset + + Returns: + An instance of `Dataset`. + """ + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = path + tmp_args["selection"] = selection + return self._construct_dataset(tmp_args)
+ +
+ +
+
+
+
+
+ + + + + \ No newline at end of file diff --git a/_modules/graphnet/models/standard_model.html b/_modules/graphnet/models/standard_model.html index 8d219a1e4..6b72ed310 100644 --- a/_modules/graphnet/models/standard_model.html +++ b/_modules/graphnet/models/standard_model.html @@ -613,6 +613,9 @@

Source code for graph on_step=False, sync_dist=True, ) + + current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] + self.log("lr", current_lr, prog_bar=True, on_step=True) return loss diff --git a/_modules/graphnet/training/utils.html b/_modules/graphnet/training/utils.html index 547bf5619..b2e9890bd 100644 --- a/_modules/graphnet/training/utils.html +++ b/_modules/graphnet/training/utils.html @@ -659,6 +659,24 @@

Source code for graphnet.tra model.save(path + "/" + tag + "_model.pth") Logger().info("Results saved at: \n %s" % path) + + +
+[docs] +def save_selection(selection: List[int], file_path: str) -> None: + """Save the list of event numbers to a CSV file. + + Args: + selection: List of event ids. + file_path: File path to save the selection. + """ + assert isinstance( + selection, list + ), "Selection should be a list of integers." + with open(file_path, "w") as f: + f.write(",".join(map(str, selection))) + f.write("\n")
+ diff --git a/_modules/index.html b/_modules/index.html index b5d45ae22..d74b408da 100644 --- a/_modules/index.html +++ b/_modules/index.html @@ -324,6 +324,7 @@

All modules for which code is available

+ +
  • + + + datamodule + +
  • @@ -604,12 +611,12 @@ -
    Skip to content +
    + +
    + + +
    + + + + +
    +
    + +
    +
    +
    + +
    +
    +
    +
    +
    +
    + + +
    +
    +
    + +
    +
    + +
    +

    datamodule

    +

    Base Dataloader class(es) used in graphnet.

    +
    +
    +class graphnet.data.datamodule.GraphNeTDataModule(dataset_reference, dataset_args, selection, test_selection, train_dataloader_kwargs, validation_dataloader_kwargs, test_dataloader_kwargs, train_val_split=[0.9, 0.1], split_seed)[source]
    +

    Bases: LightningDataModule, Logger

    +

    General Class for DataLoader Construction.

    +

    Create dataloaders from dataset.

    +
    +
    Parameters:
    +
      +
    • dataset_reference (Union[SQLiteDataset, ParquetDataset, Dataset]) – A non-instantiated reference +to the dataset class.

    • +
    • dataset_args (Dict[str, Any]) – Arguments to instantiate +graphnet.data.dataset.Dataset with.

    • +
    • selection (Union[List[int], List[List[int]], None], default: None) – (Optional) a list of event id’s used for training +and validation, Default None.

    • +
    • test_selection (Union[List[int], List[List[int]], None], default: None) – (Optional) a list of event id’s used for testing, +Default None.

    • +
    • train_dataloader_kwargs (Optional[Dict[str, Any]], default: None) – Arguments for the training DataLoader, +Default None.

    • +
    • validation_dataloader_kwargs (Optional[Dict[str, Any]], default: None) – Arguments for the validation +DataLoader, Default None.

    • +
    • test_dataloader_kwargs (Optional[Dict[str, Any]], default: None) – Arguments for the test DataLoader, +Default None.

    • +
    • train_val_split (Optional) – Split ratio for training and +validation sets. Default is [0.9, 0.10].

    • +
    • split_seed (int, default: 42) – seed used for shuffling and splitting selections into +train/validation, Default 42.

    • +
    +
    +
    +
    +
    +prepare_data()[source]
    +

    Prepare the dataset for training.

    +
    +
    Return type:
    +

    None

    +
    +
    +
    +
    +
    +setup(stage)[source]
    +

    Prepare Datasets for DataLoaders.

    +
    +
    Parameters:
    +

    stage (str) – lightning stage. Either “fit, validate, test, predict”

    +
    +
    Return type:
    +

    None

    +
    +
    +
    +
    +
    +property train_dataloader: DataLoader
    +

    Prepare and return the training DataLoader.

    +
    +
    Returns:
    +

    The DataLoader configured for training.

    +
    +
    Return type:
    +

    DataLoader

    +
    +
    +
    +
    +
    +property val_dataloader: DataLoader
    +

    Prepare and return the validation DataLoader.

    +
    +
    Returns:
    +

    The DataLoader configured for validation.

    +
    +
    Return type:
    +

    DataLoader

    +
    +
    +
    +
    +
    +property test_dataloader: DataLoader
    +

    Prepare and return the test DataLoader.

    +
    +
    Returns:
    +

    The DataLoader configured for testing.

    +
    +
    Return type:
    +

    DataLoader

    +
    +
    +
    +
    +
    +teardown()[source]
    +

    Perform any necessary cleanup or shutdown procedures.

    +

    This method can be used for tasks such as closing SQLite connections +after training. Override this method as needed.

    +
    +
    Return type:
    +

    None

    +
    +
    Returns:
    +

    None

    +
    +
    +
    +
    +
    + + +
    +
    +
    +
    +
    + + + + + \ No newline at end of file diff --git a/api/graphnet.data.dataset.dataset.html b/api/graphnet.data.dataset.dataset.html index fae36ebe0..bfd116024 100644 --- a/api/graphnet.data.dataset.dataset.html +++ b/api/graphnet.data.dataset.dataset.html @@ -498,6 +498,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.dataset.html b/api/graphnet.data.dataset.html index 4dc0f578f..cefcbc843 100644 --- a/api/graphnet.data.dataset.html +++ b/api/graphnet.data.dataset.html @@ -391,6 +391,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.dataset.parquet.html b/api/graphnet.data.dataset.parquet.html index d66b67d51..25a99d006 100644 --- a/api/graphnet.data.dataset.parquet.html +++ b/api/graphnet.data.dataset.parquet.html @@ -399,6 +399,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.dataset.parquet.parquet_dataset.html b/api/graphnet.data.dataset.parquet.parquet_dataset.html index 349f42c6b..7461072d4 100644 --- a/api/graphnet.data.dataset.parquet.parquet_dataset.html +++ b/api/graphnet.data.dataset.parquet.parquet_dataset.html @@ -425,6 +425,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.dataset.sqlite.html b/api/graphnet.data.dataset.sqlite.html index 8c8466a90..7cfb3ec50 100644 --- a/api/graphnet.data.dataset.sqlite.html +++ b/api/graphnet.data.dataset.sqlite.html @@ -399,6 +399,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html index e7c699424..cabd2fa0b 100644 --- a/api/graphnet.data.dataset.sqlite.sqlite_dataset.html +++ b/api/graphnet.data.dataset.sqlite.sqlite_dataset.html @@ -425,6 +425,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.html b/api/graphnet.data.extractors.html index 602f50a76..386e3f8af 100644 --- a/api/graphnet.data.extractors.html +++ b/api/graphnet.data.extractors.html @@ -461,6 +461,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3extractor.html b/api/graphnet.data.extractors.i3extractor.html index 088aa5ada..f13c0bef6 100644 --- a/api/graphnet.data.extractors.i3extractor.html +++ b/api/graphnet.data.extractors.i3extractor.html @@ -516,6 +516,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3featureextractor.html b/api/graphnet.data.extractors.i3featureextractor.html index ac3be3fbb..36d5ebe5e 100644 --- a/api/graphnet.data.extractors.i3featureextractor.html +++ b/api/graphnet.data.extractors.i3featureextractor.html @@ -512,6 +512,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3genericextractor.html b/api/graphnet.data.extractors.i3genericextractor.html index ace919977..ea92ae1de 100644 --- a/api/graphnet.data.extractors.i3genericextractor.html +++ b/api/graphnet.data.extractors.i3genericextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3hybridrecoextractor.html b/api/graphnet.data.extractors.i3hybridrecoextractor.html index 2219ddab5..f9878446e 100644 --- a/api/graphnet.data.extractors.i3hybridrecoextractor.html +++ b/api/graphnet.data.extractors.i3hybridrecoextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html index c8a12e6b6..3a9424cce 100644 --- a/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html +++ b/api/graphnet.data.extractors.i3ntmuonlabelsextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3particleextractor.html b/api/graphnet.data.extractors.i3particleextractor.html index d0d8cf1bc..6835ca634 100644 --- a/api/graphnet.data.extractors.i3particleextractor.html +++ b/api/graphnet.data.extractors.i3particleextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3pisaextractor.html b/api/graphnet.data.extractors.i3pisaextractor.html index 6bbf3f02b..6813730e8 100644 --- a/api/graphnet.data.extractors.i3pisaextractor.html +++ b/api/graphnet.data.extractors.i3pisaextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3quesoextractor.html b/api/graphnet.data.extractors.i3quesoextractor.html index 42497984b..ac54a3905 100644 --- a/api/graphnet.data.extractors.i3quesoextractor.html +++ b/api/graphnet.data.extractors.i3quesoextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3retroextractor.html b/api/graphnet.data.extractors.i3retroextractor.html index f7aee6ee2..230065dd6 100644 --- a/api/graphnet.data.extractors.i3retroextractor.html +++ b/api/graphnet.data.extractors.i3retroextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3splinempeextractor.html b/api/graphnet.data.extractors.i3splinempeextractor.html index f9755a130..121cbff6c 100644 --- a/api/graphnet.data.extractors.i3splinempeextractor.html +++ b/api/graphnet.data.extractors.i3splinempeextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3truthextractor.html b/api/graphnet.data.extractors.i3truthextractor.html index ded247fae..bf98709d4 100644 --- a/api/graphnet.data.extractors.i3truthextractor.html +++ b/api/graphnet.data.extractors.i3truthextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.i3tumextractor.html b/api/graphnet.data.extractors.i3tumextractor.html index 2bda092dc..f32febcbb 100644 --- a/api/graphnet.data.extractors.i3tumextractor.html +++ b/api/graphnet.data.extractors.i3tumextractor.html @@ -476,6 +476,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.utilities.collections.html b/api/graphnet.data.extractors.utilities.collections.html index 2ba4e4c1a..2d75afdc0 100644 --- a/api/graphnet.data.extractors.utilities.collections.html +++ b/api/graphnet.data.extractors.utilities.collections.html @@ -516,6 +516,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.utilities.frames.html b/api/graphnet.data.extractors.utilities.frames.html index 596e21fca..615ecb4cc 100644 --- a/api/graphnet.data.extractors.utilities.frames.html +++ b/api/graphnet.data.extractors.utilities.frames.html @@ -516,6 +516,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.utilities.html b/api/graphnet.data.extractors.utilities.html index 4240ace5f..79d4aa2d9 100644 --- a/api/graphnet.data.extractors.utilities.html +++ b/api/graphnet.data.extractors.utilities.html @@ -483,6 +483,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.extractors.utilities.types.html b/api/graphnet.data.extractors.utilities.types.html index cbf5c13bb..82941ae6d 100644 --- a/api/graphnet.data.extractors.utilities.types.html +++ b/api/graphnet.data.extractors.utilities.types.html @@ -570,6 +570,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • diff --git a/api/graphnet.data.filters.html b/api/graphnet.data.filters.html index ab303e539..91e99443b 100644 --- a/api/graphnet.data.filters.html +++ b/api/graphnet.data.filters.html @@ -130,7 +130,7 @@ - + @@ -361,6 +361,13 @@ dataloader +
  • +
  • + + + datamodule + +
  • @@ -573,7 +580,7 @@
  • +
  • + + + datamodule + +
  • @@ -520,6 +527,10 @@
  • DataLoader
  • +
  • datamodule +
  • filters diff --git a/api/graphnet.training.html b/api/graphnet.training.html index d33bd01fc..43a1760a0 100644 --- a/api/graphnet.training.html +++ b/api/graphnet.training.html @@ -456,6 +456,7 @@
  • make_train_validation_dataloader()
  • get_predictions()
  • save_results()
  • +
  • save_selection()
  • weight_fitting
  • @@ -427,6 +429,13 @@ save_results() + +
  • + + + save_selection() + +
  • @@ -482,6 +491,8 @@
  • get_predictions()
  • save_results() +
  • +
  • save_selection()
  • @@ -631,6 +642,22 @@ +
    +
    +graphnet.training.utils.save_selection(selection, file_path)[source]
    +

    Save the list of event numbers to a CSV file.

    +
    +
    Parameters:
    +
      +
    • selection (List[int]) – List of event ids.

    • +
    • file_path (str) – File path to save the selection.

    • +
    +
    +
    Return type:
    +

    None

    +
    +
    +
    diff --git a/genindex.html b/genindex.html index 5138df542..cfa4f7184 100644 --- a/genindex.html +++ b/genindex.html @@ -833,6 +833,13 @@

    G

    +
  • + graphnet.data.datamodule + +
  • @@ -1474,6 +1481,8 @@

    G

  • module
  • +
  • GraphNeTDataModule (class in graphnet.data.datamodule) +
  • GraphnetEarlyStopping (class in graphnet.training.callbacks)
  • GraphNeTI3Module (class in graphnet.deployment.i3modules.graphnet_module) @@ -1765,6 +1774,8 @@

    M

  • graphnet.data.dataconverter
  • graphnet.data.dataloader +
  • +
  • graphnet.data.datamodule
  • graphnet.data.dataset
  • @@ -2079,10 +2090,10 @@

    P

  • plot_1D_contour() (in module graphnet.pisa.plotting)
  • - - + + -