From 146f66ed6c6507c39dde170c5285199dd16da01c Mon Sep 17 00:00:00 2001 From: samadpls Date: Sun, 3 Dec 2023 01:32:52 +0500 Subject: [PATCH 01/14] Added flexible DataLoader configurations to `GraphNeTDataModule` Signed-off-by: samadpls --- setup.py | 1 + src/graphnet/data/datamodule.py | 338 ++++++++++++++++++++++++++++++++ src/graphnet/training/utils.py | 13 ++ 3 files changed, 352 insertions(+) create mode 100644 src/graphnet/data/datamodule.py diff --git a/setup.py b/setup.py index 3b70233ab..1ec34a0f0 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ "timer>=0.2", "tqdm>=4.64", "wandb>=0.12", + "pytorch-lightning>=2.1.2", ] EXTRAS_REQUIRE = { diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py new file mode 100644 index 000000000..1f8398c8b --- /dev/null +++ b/src/graphnet/data/datamodule.py @@ -0,0 +1,338 @@ +from typing import Dict, Any, Optional, List, Tuple, Union +import lightning as L +from torch.utils.data import DataLoader +from copy import deepcopy +from sklearn.model_selection import train_test_split +import pandas as pd + +from graphnet.data.dataset import ( + Dataset, + EnsembleDataset, + SQLiteDataset, + ParquetDataset, +) +from graphnet.utilities.logging import Logger +from graphnet.training.utils import save_selection + + +class GraphNeTDataModule(L.LightningDataModule, Logger): + """General Class for DataLoader Construction.""" + + def __init__( + self, + dataset_reference: Union[SQLiteDataset, ParquetDataset, Dataset], + selection: Optional[Union[List[int], List[List[int]]]], + test_selection: Optional[Union[List[int], List[List[int]]]], + dataset_args: Dict[str, Any], + train_dataloader_kwargs: Optional[Dict[str, Any]] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, + train_val_split: Optional[List[float, float]] = [0.9, 0.10], + split_seed: int = 42, + ) -> None: + """Create dataloaders from dataset. + + Args: + dataset_reference: A non-instantiated reference to the dataset class. + selection: (Optional) a list of event id's used for training and validation. + test_selection: (Optional) a list of event id's used for testing. + dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with. + train_dataloader_kwargs: Arguments for the training DataLoader. + validation_dataloader_kwargs: Arguments for the validation DataLoader. + test_dataloader_kwargs: Arguments for the test DataLoader. + split_seed: seed used for shuffling and splitting selections into train/validation. + """ + self._dataset = dataset_reference + self._selection = selection + self._train_val_split = train_val_split + self._test_selection = test_selection + self._dataset_args = dataset_args + self._rng = split_seed + + self._train_dataloader_kwargs = train_dataloader_kwargs or {} + self._validation_dataloader_kwargs = validation_dataloader_kwargs or {} + self._test_dataloader_kwargs = test_dataloader_kwargs or {} + + # If multiple dataset paths are given, we should use EnsembleDataset + self._use_ensemble_dataset = isinstance( + self._dataset_args["path"], list + ) + + def prepare_data(self) -> None: + """Prepare the dataset for training.""" + # Download method for curated datasets. Method for download is + # likely dataset-specific, so we can leave it as-is + pass + + def setup(self, stage: str) -> None: + """Prepare Datasets for DataLoaders. + + Args: + stage: lightning stage. 
Either "fit, validate, test, predict" + """ + # Sanity Checks + self._validate_dataset_class() + self._validate_dataset_args() + self._validate_dataloader_args() + + # Case-handling of selection arguments + self._resolve_selections() + + # Creation of Datasets + self._train_dataset = self._create_dataset(self._train_selection) + self._val_dataset = self._create_dataset(self._val_selection) + self._test_dataset = self._create_dataset(self._test_selection) + + return + + def train_dataloader(self) -> DataLoader: + """Prepare and return the training DataLoader. + + Returns: + DataLoader: The DataLoader configured for training. + """ + return self._create_dataloader(self._train_dataset) + + def val_dataloader(self) -> DataLoader: + """Prepare and return the validation DataLoader. + + Returns: + DataLoader: The DataLoader configured for validation. + """ + return self._create_dataloader(self._val_dataset) + + def test_dataloader(self) -> DataLoader: + """Prepare and return the test DataLoader. + + Returns: + DataLoader: The DataLoader configured for testing. + """ + return self._create_dataloader(self._test_dataset) + + def teardown(self) -> None: + """Perform any necessary cleanup or shutdown procedures. + + This method can be used for tasks such as closing SQLite connections + after training. Override this method as needed. + + Returns: + None + """ + pass + + def _create_dataloader( + self, dataset: Union[Dataset, EnsembleDataset] + ) -> DataLoader: + """Create a DataLoader for the given dataset. + + Args: + dataset (Union[Dataset, EnsembleDataset]): The dataset to create a DataLoader for. + + Returns: + DataLoader: The DataLoader configured for the given dataset. + """ + return DataLoader(dataset=dataset, **self._dataloader_args) + + def _validate_dataset_class(self) -> None: + """Sanity checks on the dataset reference (self._dataset). + + Is it a GraphNeT-compatible dataset? has the class already been + instantiated? Did they try to pass EnsembleDataset? + """ + if not isinstance( + self._dataset, (SQLiteDataset, ParquetDataset, Dataset) + ): + raise TypeError( + "dataset_reference must be an instance of SQLiteDataset, ParquetDataset, or Dataset." + ) + if isinstance(self._dataset, EnsembleDataset): + raise TypeError( + "EnsembleDataset is not allowed as dataset_reference." + ) + + def _validate_dataset_args(self) -> None: + """Sanity checks on the arguments for the dataset reference.""" + if isinstance(self._dataset_args["path"], list): + if self._selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of selections given as arg. + assert len(self._dataset_args["path"]) == len( + self._selection + ) + except AssertionError: + raise ValueError( + f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of selections ({len(self._selection)})." + ) + + if self._test_selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of test selections. + assert len(self._dataset_args["path"]) == len( + self._test_selection + ) + except AssertionError: + raise ValueError( + f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of test selections ({len(self._test_selection)}). If you'd like to test on only a subset of the {len(self._dataset_args['path'])} datasets, please provide empty test selections for the others." 
+ ) + + def _validate_dataloader_args(self) -> None: + """Sanity check on `dataloader_args`.""" + if "dataset" in self._dataloader_args: + raise ValueError("`dataloader_args` must not contain `dataset`") + + def _resolve_selections(self) -> None: + if self._test_selection is None: + self.warning_once( + f"{self.__class__.__name__} did not receive an argument for `test_selection` and will therefore not have a prediction dataloader available." + ) + if self._selection is not None: + # Split the selection into train/validation + if self._use_ensemble_dataset: + # Split every selection + self._train_selection = [] + self._val_selection = [] + for selection in self._selection: + train_selection, val_selection = self._split_selection( + selection + ) + self._train_selection.append(train_selection) + self._val_selection.append(val_selection) + + else: + # Split the only selection we got + ( + self._train_selection, + self._val_selection, + ) = self._split_selection(self._selection) + + if self._selection is None: + # If not provided, we infer it by grabbing all event ids in the dataset. + self.info( + f"{self.__class__.__name__} did not receive an argument for `selection`. Selection will automatically be created with a split of train: {self._train_val_split[0]} and validation: {self._train_val_split[1]}" + ) + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections() + + def _split_selection( + self, selection: List[int] + ) -> Tuple[List[int], List[int]]: + """Split train selection into train/validation. + + Args: + selection: Training selection to be split + + Returns: + Training selection, Validation selection. + """ + train_selection, val_selection = train_test_split( + selection, + train_size=self._train_val_split[0], + test_size=self._train_val_split[1], + random_state=self._rng, + ) + return train_selection, val_selection + + def _infer_selections(self) -> Tuple[List[int], List[int]]: + """Automatically infer training and validation selections. + + Returns: + Training selection, Validation selection + """ + if self._use_ensemble_dataset: + # We must iterate through the dataset paths and infer a train/val + # selection for each. + self._train_selection = [] + self._val_selection = [] + for dataset_path in self._dataset_args["path"]: + ( + train_selection, + val_selection, + ) = self._infer_selections_on_single_dataset(dataset_path) + self._train_selection.append(train_selection) + self._val_selection.append(val_selection) + else: + # Infer selection on a single dataset + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections_on_single_dataset( + self._dataset_args["path"] + ) + + def _infer_selections_on_single_dataset( + self, dataset_path: str + ) -> Tuple[List[int], List[int]]: + """Automatically infers training and validation selections for a single dataset. + + Args: + dataset_path (str): The path to the dataset. + + Returns: + Tuple[List[int], List[int]]: Training and validation selections. 
+ """ + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = dataset_path + tmp_dataset = self._construct_dataset(tmp_args) + + all_events = tmp_dataset._get_all_indices() # unshuffled list + + # Multiple lines to avoid one large + all_events = pd.DataFrame(all_events).sample( + frac=1, replace=False, random_state=self._rng + ) + + all_events = all_events.values.tolist() # shuffled list + return self._split_selection(all_events) + + def _create_dataset( + self, selection: Union[List[int], List[List[int]]] + ) -> Union[EnsembleDataset, Dataset]: + """Instantiate `dataset_reference`. + + Args: + selection: The selected event id's. + + Returns: + A dataset, either an instance of `EnsembleDataset` or `Dataset`. + """ + if self._use_ensemble_dataset: + # Construct multiple datasets and pass to EnsembleDataset + # At this point, we have checked that len(selection) == len(dataset_args['path']) + datasets = [] + for dataset_idx in range(len(selection)): + datasets.append( + self._create_single_dataset( + selection=selection[dataset_idx], + path=self._dataset_args["path"][dataset_idx], + ) + ) + + dataset = EnsembleDataset(datasets) + + else: + # Construct single dataset + dataset = self._create_single_dataset( + selection=selection, path=self._dataset_args["path"] + ) + return dataset + + def _create_single_dataset( + self, selection: List[int], path: str + ) -> Dataset: + """Instantiate a single `Dataset`. + + Args: + selection: A selection for a single dataset. + path: Path to a single dataset + + Returns: + An instance of `Dataset`. + """ + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = path + tmp_args["selection"] = selection + return self._dataset(**tmp_args) diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index df7c92e15..b33089ec9 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -317,3 +317,16 @@ def save_results( model.save_state_dict(path + "/" + tag + "_state_dict.pth") model.save(path + "/" + tag + "_model.pth") Logger().info("Results saved at: \n %s" % path) + + +def save_selection(selection: List[int], file_path: str) -> None: + """Save the list of event numbers to a CSV file. + + Args: + selection: List of event ids. + file_path: File path to save the selection. + """ + with open(file_path, "w") as file: + file.write("event_id\n") + for event_id in selection: + file.write(f"{event_id}\n") From 2fb0bef4718b13a887f1d8c77ef8af1c0b2ecb57 Mon Sep 17 00:00:00 2001 From: samadpls Date: Sun, 3 Dec 2023 15:11:16 +0500 Subject: [PATCH 02/14] refactored the coding style --- src/graphnet/data/datamodule.py | 41 +++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 1f8398c8b..605cbb429 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -1,3 +1,4 @@ +"""Base `Dataloader` class(es) used in `graphnet`.""" from typing import Dict, Any, Optional, List, Tuple, Union import lightning as L from torch.utils.data import DataLoader @@ -27,7 +28,7 @@ def __init__( train_dataloader_kwargs: Optional[Dict[str, Any]] = None, validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, test_dataloader_kwargs: Optional[Dict[str, Any]] = None, - train_val_split: Optional[List[float, float]] = [0.9, 0.10], + train_val_split: Optional[List[float]] = [0.9, 0.10], split_seed: int = 42, ) -> None: """Create dataloaders from dataset. 
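Note: the hunk above is needed because `typing.List` takes a single type parameter, so `List[float, float]` is not a valid annotation; the intent is a two-element split ratio. A minimal illustration (not part of the patch; the variable names are placeholders):

    from typing import List, Optional, Tuple

    # What the patch settles on: "a list of floats", relied upon to hold
    # exactly two entries, [train_fraction, validation_fraction].
    train_val_split: Optional[List[float]] = [0.9, 0.10]

    # A stricter spelling would encode the length in the type itself:
    strict_split: Tuple[float, float] = (0.9, 0.10)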
@@ -40,6 +41,7 @@ def __init__( train_dataloader_kwargs: Arguments for the training DataLoader. validation_dataloader_kwargs: Arguments for the validation DataLoader. test_dataloader_kwargs: Arguments for the test DataLoader. + train_val_split (Optional): Split ratio for training and validation sets. Default is [0.9, 0.10]. split_seed: seed used for shuffling and splitting selections into train/validation. """ self._dataset = dataset_reference @@ -81,7 +83,8 @@ def setup(self, stage: str) -> None: # Creation of Datasets self._train_dataset = self._create_dataset(self._train_selection) self._val_dataset = self._create_dataset(self._val_selection) - self._test_dataset = self._create_dataset(self._test_selection) + if self._test_selection is not None: + self._test_dataset = self._create_dataset(self._test_selection) return @@ -191,8 +194,8 @@ def _resolve_selections(self) -> None: # Split the selection into train/validation if self._use_ensemble_dataset: # Split every selection - self._train_selection = [] - self._val_selection = [] + self._train_selection: List[List[int]] = [] + self._val_selection: List[List[int]] = [] for selection in self._selection: train_selection, val_selection = self._split_selection( selection @@ -218,7 +221,7 @@ def _resolve_selections(self) -> None: ) = self._infer_selections() def _split_selection( - self, selection: List[int] + self, selection: Union[int, List[int], List[List[int]]] ) -> Tuple[List[int], List[int]]: """Split train selection into train/validation. @@ -228,12 +231,26 @@ def _split_selection( Returns: Training selection, Validation selection. """ - train_selection, val_selection = train_test_split( - selection, - train_size=self._train_val_split[0], - test_size=self._train_val_split[1], - random_state=self._rng, - ) + if isinstance(selection, int): + train_selection, val_selection = [selection], [] + elif isinstance(selection[0], list): + flat_selection = [ + item for sublist in selection for item in sublist + ] + train_selection, val_selection = train_test_split( + flat_selection, + train_size=self._train_val_split[0], + test_size=self._train_val_split[1], + random_state=self._rng, + ) + else: + train_selection, val_selection = train_test_split( + selection, + train_size=self._train_val_split[0], + test_size=self._train_val_split[1], + random_state=self._rng, + ) + return train_selection, val_selection def _infer_selections(self) -> Tuple[List[int], List[int]]: @@ -266,7 +283,7 @@ def _infer_selections(self) -> Tuple[List[int], List[int]]: def _infer_selections_on_single_dataset( self, dataset_path: str ) -> Tuple[List[int], List[int]]: - """Automatically infers training and validation selections for a single dataset. + """Automatically infers dataset train/val selections. Args: dataset_path (str): The path to the dataset. 
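Note: a usage sketch of the DataModule as the series intends it to be used, written with keyword arguments only so it is unaffected by the argument reordering in PATCH 07/14. The file path, table name and feature/truth lists below are hypothetical placeholders, and `dataset_args` must contain whatever the chosen `Dataset` class actually requires. At this point `_validate_dataset_class` still applies `isinstance` to what the docstring describes as a non-instantiated class reference; that check is revisited in PATCH 07/14.

    from graphnet.data.dataset import SQLiteDataset
    from graphnet.data.datamodule import GraphNeTDataModule

    # Arguments forwarded verbatim to the referenced Dataset class.
    dataset_args = {
        "path": "/data/prometheus-events.db",     # hypothetical database
        "pulsemaps": "total",
        "truth_table": "mc_truth",
        "features": ["dom_x", "dom_y", "dom_z"],  # placeholder feature list
        "truth": ["injection_energy"],            # placeholder truth list
    }

    dm = GraphNeTDataModule(
        dataset_reference=SQLiteDataset,   # non-instantiated class reference
        dataset_args=dataset_args,
        selection=None,                    # let the module infer train/val events
        test_selection=None,
        train_dataloader_kwargs={"batch_size": 16, "num_workers": 2},
        train_val_split=[0.9, 0.10],
        split_seed=42,
    )
    train_loader = dm.train_dataloader()   # becomes a property in PATCH 11/14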
From f4fbd05c7ec1d1f9e7abe2ffb126bf0c186da892 Mon Sep 17 00:00:00 2001 From: samadpls Date: Mon, 4 Dec 2023 12:25:22 +0500 Subject: [PATCH 03/14] Refactored the coding style --- setup.py | 1 - src/graphnet/data/datamodule.py | 89 +++++++++++++++++++++------------ 2 files changed, 57 insertions(+), 33 deletions(-) diff --git a/setup.py b/setup.py index 1ec34a0f0..3b70233ab 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,6 @@ "timer>=0.2", "tqdm>=4.64", "wandb>=0.12", - "pytorch-lightning>=2.1.2", ] EXTRAS_REQUIRE = { diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 605cbb429..aec94a481 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -1,6 +1,6 @@ """Base `Dataloader` class(es) used in `graphnet`.""" from typing import Dict, Any, Optional, List, Tuple, Union -import lightning as L +import pytorch_lightning as pl from torch.utils.data import DataLoader from copy import deepcopy from sklearn.model_selection import train_test_split @@ -16,7 +16,7 @@ from graphnet.training.utils import save_selection -class GraphNeTDataModule(L.LightningDataModule, Logger): +class GraphNeTDataModule(pl.LightningDataModule, Logger): """General Class for DataLoader Construction.""" def __init__( @@ -45,9 +45,9 @@ def __init__( split_seed: seed used for shuffling and splitting selections into train/validation. """ self._dataset = dataset_reference - self._selection = selection - self._train_val_split = train_val_split - self._test_selection = test_selection + self._selection = selection or [0] + self._train_val_split = train_val_split or [0.0] + self._test_selection = test_selection or [0.0] self._dataset_args = dataset_args self._rng = split_seed @@ -83,8 +83,7 @@ def setup(self, stage: str) -> None: # Creation of Datasets self._train_dataset = self._create_dataset(self._train_selection) self._val_dataset = self._create_dataset(self._val_selection) - if self._test_selection is not None: - self._test_dataset = self._create_dataset(self._test_selection) + self._test_dataset = self._create_dataset(self._test_selection) return @@ -112,16 +111,16 @@ def test_dataloader(self) -> DataLoader: """ return self._create_dataloader(self._test_dataset) - def teardown(self) -> None: - """Perform any necessary cleanup or shutdown procedures. + # def teardown(self) -> None: + # """Perform any necessary cleanup or shutdown procedures. - This method can be used for tasks such as closing SQLite connections - after training. Override this method as needed. + # This method can be used for tasks such as closing SQLite connections + # after training. Override this method as needed. - Returns: - None - """ - pass + # Returns: + # None + # """ + # return None def _create_dataloader( self, dataset: Union[Dataset, EnsembleDataset] @@ -134,7 +133,18 @@ def _create_dataloader( Returns: DataLoader: The DataLoader configured for the given dataset. """ - return DataLoader(dataset=dataset, **self._dataloader_args) + if dataset == self._train_dataset: + dataloader_args = self._train_dataloader_kwargs + elif dataset == self._val_dataset: + dataloader_args = self._validation_dataloader_kwargs + elif dataset == self._test_dataset: + dataloader_args = self._test_dataloader_kwargs + else: + raise ValueError( + "Unknown dataset encountered during dataloader creation." + ) + + return DataLoader(dataset=dataset, **dataloader_args) def _validate_dataset_class(self) -> None: """Sanity checks on the dataset reference (self._dataset). 
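Note: in the hunk above, `_create_dataloader` picks the keyword arguments by comparing the incoming dataset against the stored train/validation/test datasets with `==`; assuming the dataset classes do not override `__eq__`, this amounts to dispatch on object identity. An equivalent, slightly more explicit formulation (illustration only, not part of the patch):

    from torch.utils.data import DataLoader

    def _create_dataloader(self, dataset):
        # Map each known dataset object to its dataloader kwargs by identity.
        kwargs_by_dataset = {
            id(self._train_dataset): self._train_dataloader_kwargs,
            id(self._val_dataset): self._validation_dataloader_kwargs,
            id(self._test_dataset): self._test_dataloader_kwargs,
        }
        if id(dataset) not in kwargs_by_dataset:
            raise ValueError(
                "Unknown dataset encountered during dataloader creation."
            )
        return DataLoader(dataset=dataset, **kwargs_by_dataset[id(dataset)])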
@@ -182,8 +192,18 @@ def _validate_dataset_args(self) -> None: def _validate_dataloader_args(self) -> None: """Sanity check on `dataloader_args`.""" - if "dataset" in self._dataloader_args: - raise ValueError("`dataloader_args` must not contain `dataset`") + if "dataset" in self._train_dataloader_kwargs: + raise ValueError( + "`train_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._validation_dataloader_kwargs: + raise ValueError( + "`validation_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._test_dataloader_kwargs: + raise ValueError( + "`test_dataloader_kwargs` must not contain `dataset`" + ) def _resolve_selections(self) -> None: if self._test_selection is None: @@ -232,25 +252,20 @@ def _split_selection( Training selection, Validation selection. """ if isinstance(selection, int): - train_selection, val_selection = [selection], [] + flat_selection = [selection] elif isinstance(selection[0], list): flat_selection = [ item for sublist in selection for item in sublist ] - train_selection, val_selection = train_test_split( - flat_selection, - train_size=self._train_val_split[0], - test_size=self._train_val_split[1], - random_state=self._rng, - ) else: - train_selection, val_selection = train_test_split( - selection, - train_size=self._train_val_split[0], - test_size=self._train_val_split[1], - random_state=self._rng, - ) + flat_selection = selection + train_selection, val_selection = train_test_split( + flat_selection, + train_size=self._train_val_split[0], + test_size=self._train_val_split[1], + random_state=self._rng, + ) return train_selection, val_selection def _infer_selections(self) -> Tuple[List[int], List[int]]: @@ -280,6 +295,8 @@ def _infer_selections(self) -> Tuple[List[int], List[int]]: self._dataset_args["path"] ) + return (self._train_selection, self._val_selection) + def _infer_selections_on_single_dataset( self, dataset_path: str ) -> Tuple[List[int], List[int]]: @@ -305,8 +322,16 @@ def _infer_selections_on_single_dataset( all_events = all_events.values.tolist() # shuffled list return self._split_selection(all_events) + def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dict[str, Any]: + """Construct dataset.""" + return tmp_args + + def _get_all_indices(self): + """Shuffle the list.""" + return list + def _create_dataset( - self, selection: Union[List[int], List[List[int]]] + self, selection: Union[List[int], List[List[int]], List[float]] ) -> Union[EnsembleDataset, Dataset]: """Instantiate `dataset_reference`. From 302c808877e5bab01aa2890460bf794bef4a0c34 Mon Sep 17 00:00:00 2001 From: samadpls Date: Mon, 18 Dec 2023 23:04:06 +0500 Subject: [PATCH 04/14] updated `datamodule.py` file --- src/graphnet/data/datamodule.py | 64 ++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index aec94a481..c16bee977 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -111,16 +111,16 @@ def test_dataloader(self) -> DataLoader: """ return self._create_dataloader(self._test_dataset) - # def teardown(self) -> None: - # """Perform any necessary cleanup or shutdown procedures. + def teardown(self) -> None: # type: ignore[override] + """Perform any necessary cleanup or shutdown procedures. - # This method can be used for tasks such as closing SQLite connections - # after training. Override this method as needed. 
+ This method can be used for tasks such as closing SQLite connections + after training. Override this method as needed. - # Returns: - # None - # """ - # return None + Returns: + None + """ + return None def _create_dataloader( self, dataset: Union[Dataset, EnsembleDataset] @@ -214,8 +214,8 @@ def _resolve_selections(self) -> None: # Split the selection into train/validation if self._use_ensemble_dataset: # Split every selection - self._train_selection: List[List[int]] = [] - self._val_selection: List[List[int]] = [] + self._train_selection = [] + self._val_selection = [] for selection in self._selection: train_selection, val_selection = self._split_selection( selection @@ -225,10 +225,13 @@ def _resolve_selections(self) -> None: else: # Split the only selection we got + assert isinstance(self._selection, list) ( self._train_selection, self._val_selection, - ) = self._split_selection(self._selection) + ) = self._split_selection( # type: ignore + self._selection + ) if self._selection is None: # If not provided, we infer it by grabbing all event ids in the dataset. @@ -251,14 +254,16 @@ def _split_selection( Returns: Training selection, Validation selection. """ + assert isinstance(selection, (int, list)) if isinstance(selection, int): flat_selection = [selection] elif isinstance(selection[0], list): flat_selection = [ - item for sublist in selection for item in sublist + item for sublist in selection for item in sublist # type: ignore ] else: - flat_selection = selection + flat_selection = selection # type: ignore + assert isinstance(flat_selection, list) train_selection, val_selection = train_test_split( flat_selection, @@ -284,8 +289,8 @@ def _infer_selections(self) -> Tuple[List[int], List[int]]: train_selection, val_selection, ) = self._infer_selections_on_single_dataset(dataset_path) - self._train_selection.append(train_selection) - self._val_selection.append(val_selection) + self._train_selection.extend(train_selection) # type: ignore + self._val_selection.extend(val_selection) # type: ignore else: # Infer selection on a single dataset ( @@ -295,7 +300,7 @@ def _infer_selections(self) -> Tuple[List[int], List[int]]: self._dataset_args["path"] ) - return (self._train_selection, self._val_selection) + return (self._train_selection, self._val_selection) # type: ignore def _infer_selections_on_single_dataset( self, dataset_path: str @@ -322,14 +327,23 @@ def _infer_selections_on_single_dataset( all_events = all_events.values.tolist() # shuffled list return self._split_selection(all_events) - def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dict[str, Any]: - """Construct dataset.""" - return tmp_args - def _get_all_indices(self): - """Shuffle the list.""" + """Get all indices. + + Return: + List of indices in an unshuffled order. + """ return list + def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dict[str, Any]: + """Construct dataset. + + Return: + Dataset object constructed from input arguments. 
+ """ + # instance dataset class , that set of argunment , + return tmp_args + def _create_dataset( self, selection: Union[List[int], List[List[int]], List[float]] ) -> Union[EnsembleDataset, Dataset]: @@ -348,7 +362,7 @@ def _create_dataset( for dataset_idx in range(len(selection)): datasets.append( self._create_single_dataset( - selection=selection[dataset_idx], + selection=selection[dataset_idx], # type: ignore path=self._dataset_args["path"][dataset_idx], ) ) @@ -358,12 +372,14 @@ def _create_dataset( else: # Construct single dataset dataset = self._create_single_dataset( - selection=selection, path=self._dataset_args["path"] + selection=selection, path=self._dataset_args["path"] # type: ignore ) return dataset def _create_single_dataset( - self, selection: List[int], path: str + self, + selection: Union[List[int], List[List[int]], List[float]], + path: str, ) -> Dataset: """Instantiate a single `Dataset`. From 0c4a9c69c08f005966d94da9ce655036a818a9fd Mon Sep 17 00:00:00 2001 From: samadpls Date: Sat, 23 Dec 2023 21:47:23 +0500 Subject: [PATCH 05/14] Refactored `GraphNeTDataModule` class Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 55 ++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index c16bee977..1d51a1658 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -5,6 +5,7 @@ from copy import deepcopy from sklearn.model_selection import train_test_split import pandas as pd +import random from graphnet.data.dataset import ( Dataset, @@ -120,7 +121,22 @@ def teardown(self) -> None: # type: ignore[override] Returns: None """ - return None + if hasattr(self, "_train_dataset") and isinstance( + self._train_dataset, SQLiteDataset + ): + self._train_dataset._close_connection() + + if hasattr(self, "_val_dataset") and isinstance( + self._val_dataset, SQLiteDataset + ): + self._val_dataset._close_connection() + + if hasattr(self, "_test_dataset") and isinstance( + self._test_dataset, SQLiteDataset + ): + self._test_dataset._close_connection() + + return def _create_dataloader( self, dataset: Union[Dataset, EnsembleDataset] @@ -149,8 +165,9 @@ def _create_dataloader( def _validate_dataset_class(self) -> None: """Sanity checks on the dataset reference (self._dataset). - Is it a GraphNeT-compatible dataset? has the class already been - instantiated? Did they try to pass EnsembleDataset? + Checks whether the dataset is an instance of SQLiteDataset, + ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset + type is detected, or if an EnsembleDataset is used. 
""" if not isinstance( self._dataset, (SQLiteDataset, ParquetDataset, Dataset) @@ -296,7 +313,7 @@ def _infer_selections(self) -> Tuple[List[int], List[int]]: ( self._train_selection, self._val_selection, - ) = self._infer_selections_on_single_dataset( + ) = self._infer_selections_on_single_dataset( # type: ignore self._dataset_args["path"] ) @@ -317,32 +334,46 @@ def _infer_selections_on_single_dataset( tmp_args["path"] = dataset_path tmp_dataset = self._construct_dataset(tmp_args) - all_events = tmp_dataset._get_all_indices() # unshuffled list + all_events = ( + tmp_dataset._get_all_indices() + ) # unshuffled list, # sequential index # Multiple lines to avoid one large all_events = pd.DataFrame(all_events).sample( frac=1, replace=False, random_state=self._rng ) - all_events = all_events.values.tolist() # shuffled list + all_events = random.sample( + all_events, len(all_events) + ) # shuffled list return self._split_selection(all_events) - def _get_all_indices(self): + def _get_all_indices(self) -> List[int]: """Get all indices. Return: List of indices in an unshuffled order. """ - return list + if self._use_ensemble_dataset: + all_indices = [] + for dataset_path in self._dataset_args["path"]: + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = dataset_path + tmp_dataset = self._construct_dataset(tmp_args) + all_indices.extend(tmp_dataset._get_all_indices()) + else: + all_indices = self._dataset._get_all_indices() + + return all_indices - def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dict[str, Any]: + def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: """Construct dataset. Return: Dataset object constructed from input arguments. """ - # instance dataset class , that set of argunment , - return tmp_args + dataset = self._dataset(**tmp_args) + return dataset def _create_dataset( self, selection: Union[List[int], List[List[int]], List[float]] @@ -393,4 +424,4 @@ def _create_single_dataset( tmp_args = deepcopy(self._dataset_args) tmp_args["path"] = path tmp_args["selection"] = selection - return self._dataset(**tmp_args) + return self._construct_dataset(tmp_args) From e0ea137e955eb5c6159e07066f383442a7e69212 Mon Sep 17 00:00:00 2001 From: samadpls Date: Sun, 24 Dec 2023 15:41:25 +0500 Subject: [PATCH 06/14] added `_construct_dataset` method Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 1d51a1658..c8268c990 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -348,24 +348,6 @@ def _infer_selections_on_single_dataset( ) # shuffled list return self._split_selection(all_events) - def _get_all_indices(self) -> List[int]: - """Get all indices. - - Return: - List of indices in an unshuffled order. - """ - if self._use_ensemble_dataset: - all_indices = [] - for dataset_path in self._dataset_args["path"]: - tmp_args = deepcopy(self._dataset_args) - tmp_args["path"] = dataset_path - tmp_dataset = self._construct_dataset(tmp_args) - all_indices.extend(tmp_dataset._get_all_indices()) - else: - all_indices = self._dataset._get_all_indices() - - return all_indices - def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: """Construct dataset. 
From 58d2f2740bda1150591457f7e9b252088bc39ad6 Mon Sep 17 00:00:00 2001 From: samadpls Date: Wed, 7 Feb 2024 01:03:03 +0500 Subject: [PATCH 07/14] Refactor GraphNeTDataModule and add unit test for save_selection function --- src/graphnet/data/datamodule.py | 51 +++++++++++++++++++-------------- src/graphnet/training/utils.py | 10 ++++--- tests/data/test_datamodule.py | 33 +++++++++++++++++++++ 3 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 tests/data/test_datamodule.py diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index c8268c990..f4db88e88 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -23,9 +23,9 @@ class GraphNeTDataModule(pl.LightningDataModule, Logger): def __init__( self, dataset_reference: Union[SQLiteDataset, ParquetDataset, Dataset], - selection: Optional[Union[List[int], List[List[int]]]], - test_selection: Optional[Union[List[int], List[List[int]]]], dataset_args: Dict[str, Any], + selection: Optional[Union[List[int], List[List[int]]]] = None, + test_selection: Optional[Union[List[int], List[List[int]]]] = None, train_dataloader_kwargs: Optional[Dict[str, Any]] = None, validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, test_dataloader_kwargs: Optional[Dict[str, Any]] = None, @@ -36,20 +36,22 @@ def __init__( Args: dataset_reference: A non-instantiated reference to the dataset class. + dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with. selection: (Optional) a list of event id's used for training and validation. test_selection: (Optional) a list of event id's used for testing. - dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with. train_dataloader_kwargs: Arguments for the training DataLoader. validation_dataloader_kwargs: Arguments for the validation DataLoader. test_dataloader_kwargs: Arguments for the test DataLoader. train_val_split (Optional): Split ratio for training and validation sets. Default is [0.9, 0.10]. split_seed: seed used for shuffling and splitting selections into train/validation. """ + Logger.__init__(self) + self._make_sure_root_logger_is_configured() self._dataset = dataset_reference - self._selection = selection or [0] - self._train_val_split = train_val_split or [0.0] - self._test_selection = test_selection or [0.0] self._dataset_args = dataset_args + self._selection = selection + self._test_selection = test_selection + self._train_val_split = train_val_split or [0.0] self._rng = split_seed self._train_dataloader_kwargs = train_dataloader_kwargs or {} @@ -61,6 +63,8 @@ def __init__( self._dataset_args["path"], list ) + self.setup("") + def prepare_data(self) -> None: """Prepare the dataset for training.""" # Download method for curated datasets. Method for download is @@ -82,9 +86,10 @@ def setup(self, stage: str) -> None: self._resolve_selections() # Creation of Datasets + # self._dataset = self._create_dataset(self.) self._train_dataset = self._create_dataset(self._train_selection) self._val_dataset = self._create_dataset(self._val_selection) - self._test_dataset = self._create_dataset(self._test_selection) + self._test_dataset = self._create_dataset(self._test_selection) # type: ignore return @@ -169,12 +174,14 @@ def _validate_dataset_class(self) -> None: ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset type is detected, or if an EnsembleDataset is used. 
""" - if not isinstance( - self._dataset, (SQLiteDataset, ParquetDataset, Dataset) - ): - raise TypeError( - "dataset_reference must be an instance of SQLiteDataset, ParquetDataset, or Dataset." - ) + print(self._dataset, "Dataset\n") + print( + f"Type of self._dataset before validation check: {type(self._dataset)}" + ) + # if type(self._dataset) not in [SQLiteDataset, ParquetDataset, Dataset]: + # raise TypeError( + # "dataset_reference must be an instance of SQLiteDataset, ParquetDataset, or Dataset." + # ) if isinstance(self._dataset, EnsembleDataset): raise TypeError( "EnsembleDataset is not allowed as dataset_reference." @@ -250,7 +257,7 @@ def _resolve_selections(self) -> None: self._selection ) - if self._selection is None: + else: # selection is None # If not provided, we infer it by grabbing all event ids in the dataset. self.info( f"{self.__class__.__name__} did not receive an argument for `selection`. Selection will automatically be created with a split of train: {self._train_val_split[0]} and validation: {self._train_val_split[1]}" @@ -258,7 +265,7 @@ def _resolve_selections(self) -> None: ( self._train_selection, self._val_selection, - ) = self._infer_selections() + ) = self._infer_selections() # type: ignore def _split_selection( self, selection: Union[int, List[int], List[List[int]]] @@ -336,16 +343,15 @@ def _infer_selections_on_single_dataset( all_events = ( tmp_dataset._get_all_indices() - ) # unshuffled list, # sequential index + ) # unshuffled list, sequential index # Multiple lines to avoid one large - all_events = pd.DataFrame(all_events).sample( - frac=1, replace=False, random_state=self._rng - ) - - all_events = random.sample( - all_events, len(all_events) + all_events = ( + pd.DataFrame(all_events) + .sample(frac=1, replace=False, random_state=self._rng) + .values.tolist() ) # shuffled list + return self._split_selection(all_events) def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: @@ -354,6 +360,7 @@ def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: Return: Dataset object constructed from input arguments. """ + print(tmp_args, "temp argument") dataset = self._dataset(**tmp_args) return dataset diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index b33089ec9..fca4a21e0 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -326,7 +326,9 @@ def save_selection(selection: List[int], file_path: str) -> None: selection: List of event ids. file_path: File path to save the selection. """ - with open(file_path, "w") as file: - file.write("event_id\n") - for event_id in selection: - file.write(f"{event_id}\n") + assert isinstance( + selection, list + ), "Selection should be a list of integers." 
+ with open(file_path, "w") as f: + f.write(",".join(map(str, selection))) + f.write("\n") diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py new file mode 100644 index 000000000..e826bf189 --- /dev/null +++ b/tests/data/test_datamodule.py @@ -0,0 +1,33 @@ +"""Unit tests for DataModule.""" + +from typing import Union, Dict, Any, List + +import os +import pandas as pd +import pytest +from graphnet.data.constants import FEATURES, TRUTH + +from graphnet.training.utils import save_selection + + +@pytest.fixture +def selection() -> List[int]: + """Return a selection.""" + return [1, 2, 3, 4, 5] + + +@pytest.fixture +def file_path(tmpdir: str) -> str: + """Return a file path.""" + return os.path.join(tmpdir, "selection.csv") + + +def test_save_selection(selection: List[int], file_path: str) -> None: + """Test `save_selection` function.""" + save_selection(selection, file_path) + + assert os.path.exists(file_path) + + with open(file_path, "r") as f: + content = f.read() + assert content.strip() == "1,2,3,4,5" From 71a798b2063a178e978b5477d02be24baef0ef3d Mon Sep 17 00:00:00 2001 From: samadpls Date: Wed, 7 Feb 2024 17:09:27 +0500 Subject: [PATCH 08/14] Refactored `GraphNeTDataModule` and add test cases for `without_selection` Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 28 ++++++---- tests/data/test_datamodule.py | 98 +++++++++++++++++++++++++++++++-- 2 files changed, 110 insertions(+), 16 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index f4db88e88..e35cfba2e 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -37,13 +37,13 @@ def __init__( Args: dataset_reference: A non-instantiated reference to the dataset class. dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with. - selection: (Optional) a list of event id's used for training and validation. - test_selection: (Optional) a list of event id's used for testing. - train_dataloader_kwargs: Arguments for the training DataLoader. - validation_dataloader_kwargs: Arguments for the validation DataLoader. - test_dataloader_kwargs: Arguments for the test DataLoader. + selection: (Optional) a list of event id's used for training and validation, Default None. + test_selection: (Optional) a list of event id's used for testing, Default None. + train_dataloader_kwargs: Arguments for the training DataLoader, Default None. + validation_dataloader_kwargs: Arguments for the validation DataLoader, Default None. + test_dataloader_kwargs: Arguments for the test DataLoader, Default None. train_val_split (Optional): Split ratio for training and validation sets. Default is [0.9, 0.10]. - split_seed: seed used for shuffling and splitting selections into train/validation. + split_seed: seed used for shuffling and splitting selections into train/validation, Default 42. """ Logger.__init__(self) self._make_sure_root_logger_is_configured() @@ -63,7 +63,7 @@ def __init__( self._dataset_args["path"], list ) - self.setup("") + self.setup("fit") def prepare_data(self) -> None: """Prepare the dataset for training.""" @@ -86,10 +86,11 @@ def setup(self, stage: str) -> None: self._resolve_selections() # Creation of Datasets - # self._dataset = self._create_dataset(self.) 
- self._train_dataset = self._create_dataset(self._train_selection) - self._val_dataset = self._create_dataset(self._val_selection) - self._test_dataset = self._create_dataset(self._test_selection) # type: ignore + if stage == "fit" or stage == "validate": + self._train_dataset = self._create_dataset(self._train_selection) + self._val_dataset = self._create_dataset(self._val_selection) + elif stage == "test": + self._test_dataset = self._create_dataset(self._test_selection) # type: ignore return @@ -165,6 +166,9 @@ def _create_dataloader( "Unknown dataset encountered during dataloader creation." ) + if dataloader_args is None: + raise AttributeError("Dataloader arguments not provided.") + return DataLoader(dataset=dataset, **dataloader_args) def _validate_dataset_class(self) -> None: @@ -361,7 +365,7 @@ def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: Dataset object constructed from input arguments. """ print(tmp_args, "temp argument") - dataset = self._dataset(**tmp_args) + dataset = self._dataset(**tmp_args) # type: ignore return dataset def _create_dataset( diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index e826bf189..946dd8906 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,15 +1,66 @@ """Unit tests for DataModule.""" -from typing import Union, Dict, Any, List - import os -import pandas as pd +from typing import List, Any + import pytest -from graphnet.data.constants import FEATURES, TRUTH +from torch.utils.data import SequentialSampler +from graphnet.constants import EXAMPLE_DATA_DIR +from graphnet.data.constants import FEATURES, TRUTH +from graphnet.data.dataset import SQLiteDataset, ParquetDataset +from graphnet.data.datamodule import GraphNeTDataModule +from graphnet.models.detector import IceCubeDeepCore +from graphnet.models.graphs import KNNGraph +from graphnet.models.graphs.nodes import NodesAsPulses from graphnet.training.utils import save_selection +@pytest.fixture +def dataset_ref(request: pytest.FixtureRequest) -> pytest.FixtureRequest: + """Return the dataset reference.""" + return request.param + + +@pytest.fixture +def dataset_setup(dataset_ref: pytest.FixtureRequest) -> tuple: + """Set up the dataset for testing. + + Args: + dataset_ref: The dataset reference. + + Returns: + A tuple with the dataset reference, dataset kwargs, and dataloader kwargs. 
+ """ + # Grab public dataset paths + data_path = ( + f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db" + if dataset_ref is SQLiteDataset + else f"{EXAMPLE_DATA_DIR}/parquet/prometheus/prometheus-events.parquet" + ) + + # Setup basic inputs; can be altered by individual tests + graph_definition = KNNGraph( + detector=IceCubeDeepCore(), + node_definition=NodesAsPulses(), + nb_nearest_neighbours=8, + input_feature_names=FEATURES.DEEPCORE, + ) + + dataset_kwargs = { + "truth_table": "mc_truth", + "pulsemaps": "total", + "truth": TRUTH.PROMETHEUS, + "features": FEATURES.PROMETHEUS, + "path": data_path, + "graph_definition": graph_definition, + } + + dataloader_kwargs = {"batch_size": 2, "num_workers": 1} + + return dataset_ref, dataset_kwargs, dataloader_kwargs + + @pytest.fixture def selection() -> List[int]: """Return a selection.""" @@ -31,3 +82,42 @@ def test_save_selection(selection: List[int], file_path: str) -> None: with open(file_path, "r") as f: content = f.read() assert content.strip() == "1,2,3,4,5" + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_single_dataset_without_selections( + dataset_setup: tuple[Any, dict[str, Any], dict[str, int]] +) -> None: + """Verify GraphNeTDataModule behavior when no test selection is provided. + + Args: + dataset_setup: Tuple with dataset reference, dataset arguments, and dataloader arguments. + + Raises: + Exception: If the test dataloader is accessed without providing a test selection. + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + + # Only training_dataloader args + # Default values should be assigned to validation dataloader + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + ) + + train_dataloader = dm.train_dataloader() + val_dataloader = dm.val_dataloader() + print(dm.test_dataloader, "here") + + with pytest.raises(Exception): + # should fail because we provided no test selection + test_dataloader = dm.test_dataloader() # noqa + # validation loader should have shuffle = False by default + assert isinstance(val_dataloader.sampler, SequentialSampler) + # Should have identical batch_size + assert val_dataloader.batch_size != train_dataloader.batch_size + # Training dataloader should contain more batches + assert len(train_dataloader) > len(val_dataloader) From 1d257d75825b4fb1791aa0c9fc9e3195c996385d Mon Sep 17 00:00:00 2001 From: samadpls Date: Wed, 7 Feb 2024 17:34:22 +0500 Subject: [PATCH 09/14] used typing notations Signed-off-by: samadpls --- tests/data/test_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 946dd8906..9c291929f 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,7 +1,7 @@ """Unit tests for DataModule.""" import os -from typing import List, Any +from typing import List, Any, Dict, Tuple import pytest from torch.utils.data import SequentialSampler @@ -88,7 +88,7 @@ def test_save_selection(selection: List[int], file_path: str) -> None: "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True ) def test_single_dataset_without_selections( - dataset_setup: tuple[Any, dict[str, Any], dict[str, int]] + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] ) -> None: """Verify GraphNeTDataModule behavior when no test selection is provided. 
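Note: the tests above drive one test body against both backends via indirect parametrization: the parametrized value is routed into the `dataset_ref` fixture as `request.param` rather than straight into the test. The mechanism in isolation (illustration only, with placeholder string values):

    import pytest

    @pytest.fixture
    def backend(request: pytest.FixtureRequest) -> str:
        # With indirect=True the parametrized value arrives here first.
        return request.param

    @pytest.mark.parametrize("backend", ["sqlite", "parquet"], indirect=True)
    def test_backend(backend: str) -> None:
        assert backend in ("sqlite", "parquet")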
From af7dd4658732069ac948e2f8d0ad943be523ae80 Mon Sep 17 00:00:00 2001 From: samadpls Date: Wed, 7 Feb 2024 18:10:19 +0500 Subject: [PATCH 10/14] added unit test for `with_selection` use case Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 12 ++++-- tests/data/test_datamodule.py | 71 ++++++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index e35cfba2e..2a1c25d35 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -86,11 +86,15 @@ def setup(self, stage: str) -> None: self._resolve_selections() # Creation of Datasets - if stage == "fit" or stage == "validate": - self._train_dataset = self._create_dataset(self._train_selection) - self._val_dataset = self._create_dataset(self._val_selection) - elif stage == "test": + if self._test_selection is not None: self._test_dataset = self._create_dataset(self._test_selection) # type: ignore + if stage == "fit" or stage == "validate": + if self._train_selection is not None: + self._train_dataset = self._create_dataset( + self._train_selection + ) + if self._val_selection is not None: + self._val_dataset = self._create_dataset(self._val_selection) return diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 9c291929f..f1418bd4b 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -2,7 +2,8 @@ import os from typing import List, Any, Dict, Tuple - +import pandas as pd +import sqlite3 import pytest from torch.utils.data import SequentialSampler @@ -121,3 +122,71 @@ def test_single_dataset_without_selections( assert val_dataloader.batch_size != train_dataloader.batch_size # Training dataloader should contain more batches assert len(train_dataloader) > len(val_dataloader) + + +def extract_all_events_ids( + file_path: str, dataset_kwargs: Dict[str, Any] +) -> List[int]: + """Extract all available event ids.""" + if file_path.endswith(".parquet"): + selection = pd.read_parquet(file_path)["event_id"].to_numpy().tolist() + elif file_path.endswith(".db"): + with sqlite3.connect(file_path) as conn: + query = f'SELECT event_no FROM {dataset_kwargs["truth_table"]}' + selection = ( + pd.read_sql(query, conn)["event_no"].to_numpy().tolist() + ) + else: + raise AssertionError( + f"File extension not accepted: {file_path.split('.')[-1]}" + ) + return selection + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_single_dataset_with_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test that selection functionality of DataModule behaves as expected. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset arguments, and dataloader arguments. 
+ + Returns: + None + """ + # extract all events + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + file_path = dataset_kwargs["path"] + selection = extract_all_events_ids( + file_path=file_path, dataset_kwargs=dataset_kwargs + ) + + test_selection = selection[0:10] + train_val_selection = selection[10:] + + # Only training_dataloader args + # Default values should be assigned to validation dataloader + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=train_val_selection, + test_selection=test_selection, + ) + + train_dataloader = dm.train_dataloader() + val_dataloader = dm.val_dataloader() + test_dataloader = dm.test_dataloader() + + # Check that the training and validation dataloader contains + # the same number of events as was given in the selection. + assert len(train_dataloader.dataset) + len(val_dataloader.dataset) == len(train_val_selection) # type: ignore + # Check that the number of events in the test dataset is equal to the + # number of events given in the selection. + assert len(test_dataloader.dataset) == len(test_selection) # type: ignore + # Training dataloader should have more batches + assert len(train_dataloader) > len(val_dataloader) From 49e72ac7cd3573d41d7142555b2fa19c4bebd74f Mon Sep 17 00:00:00 2001 From: samadpls Date: Wed, 7 Feb 2024 22:51:51 +0500 Subject: [PATCH 11/14] Refactored `dataloader` arguments Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 17 +++--- tests/data/test_datamodule.py | 94 +++++++++++++++++++++++---------- 2 files changed, 78 insertions(+), 33 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 2a1c25d35..92751b092 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -5,7 +5,6 @@ from copy import deepcopy from sklearn.model_selection import train_test_split import pandas as pd -import random from graphnet.data.dataset import ( Dataset, @@ -14,7 +13,6 @@ ParquetDataset, ) from graphnet.utilities.logging import Logger -from graphnet.training.utils import save_selection class GraphNeTDataModule(pl.LightningDataModule, Logger): @@ -86,7 +84,10 @@ def setup(self, stage: str) -> None: self._resolve_selections() # Creation of Datasets - if self._test_selection is not None: + if ( + self._test_selection is not None + or len(self._test_dataloader_kwargs) > 0 + ): self._test_dataset = self._create_dataset(self._test_selection) # type: ignore if stage == "fit" or stage == "validate": if self._train_selection is not None: @@ -98,7 +99,8 @@ def setup(self, stage: str) -> None: return - def train_dataloader(self) -> DataLoader: + @property + def train_dataloader(self) -> DataLoader: # type: ignore[override] """Prepare and return the training DataLoader. Returns: @@ -106,7 +108,8 @@ def train_dataloader(self) -> DataLoader: """ return self._create_dataloader(self._train_dataset) - def val_dataloader(self) -> DataLoader: + @property + def val_dataloader(self) -> DataLoader: # type: ignore[override] """Prepare and return the validation DataLoader. Returns: @@ -114,7 +117,8 @@ def val_dataloader(self) -> DataLoader: """ return self._create_dataloader(self._val_dataset) - def test_dataloader(self) -> DataLoader: + @property + def test_dataloader(self) -> DataLoader: # type: ignore[override] """Prepare and return the test DataLoader. 
Returns: @@ -369,6 +373,7 @@ def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: Dataset object constructed from input arguments. """ print(tmp_args, "temp argument") + print(self._dataset, "<-dataset") dataset = self._dataset(**tmp_args) # type: ignore return dataset diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index f1418bd4b..9dab2b1d1 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,5 +1,6 @@ """Unit tests for DataModule.""" +from copy import deepcopy import os from typing import List, Any, Dict, Tuple import pandas as pd @@ -17,6 +18,25 @@ from graphnet.training.utils import save_selection +def extract_all_events_ids( + file_path: str, dataset_kwargs: Dict[str, Any] +) -> List[int]: + """Extract all available event ids.""" + if file_path.endswith(".parquet"): + selection = pd.read_parquet(file_path)["event_id"].to_numpy().tolist() + elif file_path.endswith(".db"): + with sqlite3.connect(file_path) as conn: + query = f'SELECT event_no FROM {dataset_kwargs["truth_table"]}' + selection = ( + pd.read_sql(query, conn)["event_no"].to_numpy().tolist() + ) + else: + raise AssertionError( + f"File extension not accepted: {file_path.split('.')[-1]}" + ) + return selection + + @pytest.fixture def dataset_ref(request: pytest.FixtureRequest) -> pytest.FixtureRequest: """Return the dataset reference.""" @@ -109,13 +129,12 @@ def test_single_dataset_without_selections( train_dataloader_kwargs=dataloader_kwargs, ) - train_dataloader = dm.train_dataloader() - val_dataloader = dm.val_dataloader() - print(dm.test_dataloader, "here") + train_dataloader = dm.train_dataloader + val_dataloader = dm.val_dataloader with pytest.raises(Exception): # should fail because we provided no test selection - test_dataloader = dm.test_dataloader() # noqa + test_dataloader = dm.test_dataloader # noqa # validation loader should have shuffle = False by default assert isinstance(val_dataloader.sampler, SequentialSampler) # Should have identical batch_size @@ -124,25 +143,6 @@ def test_single_dataset_without_selections( assert len(train_dataloader) > len(val_dataloader) -def extract_all_events_ids( - file_path: str, dataset_kwargs: Dict[str, Any] -) -> List[int]: - """Extract all available event ids.""" - if file_path.endswith(".parquet"): - selection = pd.read_parquet(file_path)["event_id"].to_numpy().tolist() - elif file_path.endswith(".db"): - with sqlite3.connect(file_path) as conn: - query = f'SELECT event_no FROM {dataset_kwargs["truth_table"]}' - selection = ( - pd.read_sql(query, conn)["event_no"].to_numpy().tolist() - ) - else: - raise AssertionError( - f"File extension not accepted: {file_path.split('.')[-1]}" - ) - return selection - - @pytest.mark.parametrize( "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True ) @@ -158,8 +158,8 @@ def test_single_dataset_with_selections( Returns: None """ - # extract all events dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + # extract all events file_path = dataset_kwargs["path"] selection = extract_all_events_ids( file_path=file_path, dataset_kwargs=dataset_kwargs @@ -178,9 +178,9 @@ def test_single_dataset_with_selections( test_selection=test_selection, ) - train_dataloader = dm.train_dataloader() - val_dataloader = dm.val_dataloader() - test_dataloader = dm.test_dataloader() + train_dataloader = dm.train_dataloader + val_dataloader = dm.val_dataloader + test_dataloader = dm.test_dataloader # Check that the training and validation dataloader contains # the same number 
of events as was given in the selection. @@ -190,3 +190,43 @@ def test_single_dataset_with_selections( assert len(test_dataloader.dataset) == len(test_selection) # type: ignore # Training dataloader should have more batches assert len(train_dataloader) > len(val_dataloader) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_dataloader_args( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test that arguments to dataloaders are propagated correctly. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. + + Returns: + None + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + val_dataloader_kwargs = deepcopy(dataloader_kwargs) + test_dataloader_kwargs = deepcopy(dataloader_kwargs) + + # Setting batch sizes to different values + val_dataloader_kwargs["batch_size"] = 1 + test_dataloader_kwargs["batch_size"] = 2 + dataloader_kwargs["batch_size"] = 3 + + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + validation_dataloader_kwargs=val_dataloader_kwargs, + test_dataloader_kwargs=test_dataloader_kwargs, + ) + + # Check that the resulting dataloaders have the right batch sizes + assert dm.train_dataloader.batch_size == dataloader_kwargs["batch_size"] + assert dm.val_dataloader.batch_size == val_dataloader_kwargs["batch_size"] + assert ( + dm.test_dataloader.batch_size == test_dataloader_kwargs["batch_size"] + ) From 57f49fe4191067d5ea4fae64a37c8c171fb46457 Mon Sep 17 00:00:00 2001 From: samadpls Date: Fri, 9 Feb 2024 20:46:20 +0500 Subject: [PATCH 12/14] Fix ensemble dataset functionality and added unit tests Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 4 +- tests/data/test_datamodule.py | 102 ++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 92751b092..8c5aa7aeb 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -325,8 +325,8 @@ def _infer_selections(self) -> Tuple[List[int], List[int]]: train_selection, val_selection, ) = self._infer_selections_on_single_dataset(dataset_path) - self._train_selection.extend(train_selection) # type: ignore - self._val_selection.extend(val_selection) # type: ignore + self._train_selection.append(train_selection) # type: ignore + self._val_selection.append(val_selection) # type: ignore else: # Infer selection on a single dataset ( diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 9dab2b1d1..9f8a1b745 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -230,3 +230,105 @@ def test_dataloader_args( assert ( dm.test_dataloader.batch_size == test_dataloader_kwargs["batch_size"] ) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_ensemble_dataset_without_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test ensemble dataset functionality without selections. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. 
+ + Returns: + None + """ + # Make dataloaders from single dataset + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + dm_single = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=deepcopy(dataset_kwargs), + train_dataloader_kwargs=dataloader_kwargs, + ) + + # Copy dataset path twice; mimic ensemble dataset behavior + ensemble_dataset_kwargs = deepcopy(dataset_kwargs) + dataset_path = ensemble_dataset_kwargs["path"] + ensemble_dataset_kwargs["path"] = [dataset_path, dataset_path] + + # Create dataloaders from multiple datasets + dm_ensemble = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + ) + + # Test that the ensemble dataloaders contain more batches + assert len(dm_single.train_dataloader) < len(dm_ensemble.train_dataloader) + assert len(dm_single.val_dataloader) < len(dm_ensemble.val_dataloader) + + +@pytest.mark.parametrize("dataset_ref", [SQLiteDataset, ParquetDataset]) +def test_ensemble_dataset_with_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test ensemble dataset functionality with selections. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. + + Returns: + None + """ + # extract all events + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + file_path = dataset_kwargs["path"] + selection = extract_all_events_ids( + file_path=file_path, dataset_kwargs=dataset_kwargs + ) + + # Copy dataset path twice; mimic ensemble dataset behavior + ensemble_dataset_kwargs = deepcopy(dataset_kwargs) + dataset_path = ensemble_dataset_kwargs["path"] + ensemble_dataset_kwargs["path"] = [dataset_path, dataset_path] + + # pass two datasets but only one selection; should fail: + with pytest.raises(Exception): + _ = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=selection, + ) + + # Pass two datasets and two selections; should work: + selection_1 = selection[0:20] + selection_2 = selection[0:10] + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=[selection_1, selection_2], + ) + n_events_in_dataloaders = len(dm.train_dataloader.dataset) + len(dm.val_dataloader.dataset) # type: ignore + + # Check that the number of events in train/val match + assert n_events_in_dataloaders == len(selection_1) + len(selection_2) + + # Pass two datasets, two selections and two test selections; should work + dm2 = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=[selection, selection], + test_selection=[selection_1, selection_2], + ) + + # Check that the number of events in test dataloaders are correct. 
+ n_events_in_test_dataloaders = len(dm2.test_dataloader.dataset) # type: ignore + assert n_events_in_test_dataloaders == len(selection_1) + len(selection_2) From 9f665874b9ffc02c67097e321a6d3185d9150fa7 Mon Sep 17 00:00:00 2001 From: samadpls Date: Mon, 12 Feb 2024 19:54:00 +0500 Subject: [PATCH 13/14] Refactored `GraphNeTDataModule` class Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 76 +++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 27 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 8c5aa7aeb..6c85fb061 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -34,14 +34,22 @@ def __init__( Args: dataset_reference: A non-instantiated reference to the dataset class. - dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with. - selection: (Optional) a list of event id's used for training and validation, Default None. - test_selection: (Optional) a list of event id's used for testing, Default None. - train_dataloader_kwargs: Arguments for the training DataLoader, Default None. - validation_dataloader_kwargs: Arguments for the validation DataLoader, Default None. - test_dataloader_kwargs: Arguments for the test DataLoader, Default None. - train_val_split (Optional): Split ratio for training and validation sets. Default is [0.9, 0.10]. - split_seed: seed used for shuffling and splitting selections into train/validation, Default 42. + dataset_args: Arguments to instantiate + graphnet.data.dataset.Dataset with. + selection: (Optional) a list of event id's used for training + and validation, Default None. + test_selection: (Optional) a list of event id's used for testing, + Default None. + train_dataloader_kwargs: Arguments for the training DataLoader, + Default None. + validation_dataloader_kwargs: Arguments for the validation + DataLoader, Default None. + test_dataloader_kwargs: Arguments for the test DataLoader, + Default None. + train_val_split (Optional): Split ratio for training and + validation sets. Default is [0.9, 0.10]. + split_seed: seed used for shuffling and splitting selections into + train/validation, Default 42. """ Logger.__init__(self) self._make_sure_root_logger_is_configured() @@ -158,7 +166,8 @@ def _create_dataloader( """Create a DataLoader for the given dataset. Args: - dataset (Union[Dataset, EnsembleDataset]): The dataset to create a DataLoader for. + dataset (Union[Dataset, EnsembleDataset]): + The dataset to create a DataLoader for. Returns: DataLoader: The DataLoader configured for the given dataset. @@ -186,15 +195,13 @@ def _validate_dataset_class(self) -> None: ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset type is detected, or if an EnsembleDataset is used. """ - print(self._dataset, "Dataset\n") - print( - f"Type of self._dataset before validation check: {type(self._dataset)}" - ) - # if type(self._dataset) not in [SQLiteDataset, ParquetDataset, Dataset]: - # raise TypeError( - # "dataset_reference must be an instance of SQLiteDataset, ParquetDataset, or Dataset." - # ) - if isinstance(self._dataset, EnsembleDataset): + allowed_types = (SQLiteDataset, ParquetDataset, Dataset) + if self._dataset not in allowed_types: + raise TypeError( + "dataset_reference must be an instance " + "of SQLiteDataset, ParquetDataset, or Dataset." + ) + if self._dataset is EnsembleDataset: raise TypeError( "EnsembleDataset is not allowed as dataset_reference." 
) @@ -211,7 +218,10 @@ def _validate_dataset_args(self) -> None: ) except AssertionError: raise ValueError( - f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of selections ({len(self._selection)})." + "The number of dataset paths" + f" ({len(self._dataset_args['path'])})" + " does not match the number of" + f" selections ({len(self._selection)})." ) if self._test_selection is not None: @@ -223,7 +233,13 @@ def _validate_dataset_args(self) -> None: ) except AssertionError: raise ValueError( - f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of test selections ({len(self._test_selection)}). If you'd like to test on only a subset of the {len(self._dataset_args['path'])} datasets, please provide empty test selections for the others." + "The number of dataset paths " + f" ({len(self._dataset_args['path'])}) does not match " + "the number of test selections " + f"({len(self._test_selection)}).If you'd like to test " + "on only a subset of the " + f"{len(self._dataset_args['path'])} datasets, " + "please provide empty test selections for the others." ) def _validate_dataloader_args(self) -> None: @@ -244,7 +260,9 @@ def _validate_dataloader_args(self) -> None: def _resolve_selections(self) -> None: if self._test_selection is None: self.warning_once( - f"{self.__class__.__name__} did not receive an argument for `test_selection` and will therefore not have a prediction dataloader available." + f"{self.__class__.__name__} did not receive an" + " argument for `test_selection` and will " + "therefore not have a prediction dataloader available." ) if self._selection is not None: # Split the selection into train/validation @@ -270,9 +288,14 @@ def _resolve_selections(self) -> None: ) else: # selection is None - # If not provided, we infer it by grabbing all event ids in the dataset. + # If not provided, we infer it by grabbing + # all event ids in the dataset. self.info( - f"{self.__class__.__name__} did not receive an argument for `selection`. Selection will automatically be created with a split of train: {self._train_val_split[0]} and validation: {self._train_val_split[1]}" + f"{self.__class__.__name__} did not receive an" + " for `selection`. Selection will " + "will automatically be created with a split of " + f"train: {self._train_val_split[0]} and " + f"validation: {self._train_val_split[1]}" ) ( self._train_selection, @@ -372,8 +395,6 @@ def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: Return: Dataset object constructed from input arguments. 
""" - print(tmp_args, "temp argument") - print(self._dataset, "<-dataset") dataset = self._dataset(**tmp_args) # type: ignore return dataset @@ -390,7 +411,7 @@ def _create_dataset( """ if self._use_ensemble_dataset: # Construct multiple datasets and pass to EnsembleDataset - # At this point, we have checked that len(selection) == len(dataset_args['path']) + # len(selection) == len(dataset_args['path']) datasets = [] for dataset_idx in range(len(selection)): datasets.append( @@ -405,7 +426,8 @@ def _create_dataset( else: # Construct single dataset dataset = self._create_single_dataset( - selection=selection, path=self._dataset_args["path"] # type: ignore + selection=selection, + path=self._dataset_args["path"], # type:ignore ) return dataset From 33fc7d02707e78af43c67932eddc121997313da4 Mon Sep 17 00:00:00 2001 From: samadpls Date: Mon, 12 Feb 2024 20:02:06 +0500 Subject: [PATCH 14/14] Refactored `dataset_reference` argument in GraphNeTDataModule Signed-off-by: samadpls --- src/graphnet/data/datamodule.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 6c85fb061..e629ce4a0 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -33,7 +33,8 @@ def __init__( """Create dataloaders from dataset. Args: - dataset_reference: A non-instantiated reference to the dataset class. + dataset_reference: A non-instantiated reference + to the dataset class. dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with. selection: (Optional) a list of event id's used for training @@ -96,7 +97,9 @@ def setup(self, stage: str) -> None: self._test_selection is not None or len(self._test_dataloader_kwargs) > 0 ): - self._test_dataset = self._create_dataset(self._test_selection) # type: ignore + self._test_dataset = self._create_dataset( + self._test_selection # type: ignore + ) if stage == "fit" or stage == "validate": if self._train_selection is not None: self._train_dataset = self._create_dataset( @@ -318,7 +321,9 @@ def _split_selection( flat_selection = [selection] elif isinstance(selection[0], list): flat_selection = [ - item for sublist in selection for item in sublist # type: ignore + item + for sublist in selection + for item in sublist # type: ignore ] else: flat_selection = selection # type: ignore