From 9f665874b9ffc02c67097e321a6d3185d9150fa7 Mon Sep 17 00:00:00 2001
From: samadpls
Date: Mon, 12 Feb 2024 19:54:00 +0500
Subject: [PATCH] Refactored `GraphNeTDataModule` class

Signed-off-by: samadpls
---
 src/graphnet/data/datamodule.py | 82 ++++++++++++++++++++++-----------
 1 file changed, 52 insertions(+), 30 deletions(-)

diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py
index 8c5aa7aeb..6c85fb061 100644
--- a/src/graphnet/data/datamodule.py
+++ b/src/graphnet/data/datamodule.py
@@ -34,14 +34,22 @@ def __init__(
         Args:
             dataset_reference: A non-instantiated reference to the dataset
                 class.
-            dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with.
-            selection: (Optional) a list of event id's used for training and validation, Default None.
-            test_selection: (Optional) a list of event id's used for testing, Default None.
-            train_dataloader_kwargs: Arguments for the training DataLoader, Default None.
-            validation_dataloader_kwargs: Arguments for the validation DataLoader, Default None.
-            test_dataloader_kwargs: Arguments for the test DataLoader, Default None.
-            train_val_split (Optional): Split ratio for training and validation sets. Default is [0.9, 0.10].
-            split_seed: seed used for shuffling and splitting selections into train/validation, Default 42.
+            dataset_args: Arguments to instantiate
+                graphnet.data.dataset.Dataset with.
+            selection: (Optional) a list of event ids used for training
+                and validation. Default None.
+            test_selection: (Optional) a list of event ids used for testing.
+                Default None.
+            train_dataloader_kwargs: Arguments for the training DataLoader.
+                Default None.
+            validation_dataloader_kwargs: Arguments for the validation
+                DataLoader. Default None.
+            test_dataloader_kwargs: Arguments for the test DataLoader.
+                Default None.
+            train_val_split (Optional): Split ratio for training and
+                validation sets. Default is [0.9, 0.10].
+            split_seed: Seed used for shuffling and splitting selections into
+                train/validation. Default 42.
         """
         Logger.__init__(self)
         self._make_sure_root_logger_is_configured()
@@ -158,7 +166,8 @@ def _create_dataloader(
         """Create a DataLoader for the given dataset.
 
         Args:
-            dataset (Union[Dataset, EnsembleDataset]): The dataset to create a DataLoader for.
+            dataset (Union[Dataset, EnsembleDataset]):
+                The dataset to create a DataLoader for.
 
         Returns:
             DataLoader: The DataLoader configured for the given dataset.
@@ -186,15 +195,13 @@ def _validate_dataset_class(self) -> None:
         ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset
         type is detected, or if an EnsembleDataset is used.
         """
-        print(self._dataset, "Dataset\n")
-        print(
-            f"Type of self._dataset before validation check: {type(self._dataset)}"
-        )
-        # if type(self._dataset) not in [SQLiteDataset, ParquetDataset, Dataset]:
-        #     raise TypeError(
-        #         "dataset_reference must be an instance of SQLiteDataset, ParquetDataset, or Dataset."
-        #     )
-        if isinstance(self._dataset, EnsembleDataset):
-            raise TypeError(
-                "EnsembleDataset is not allowed as dataset_reference."
-            )
+        if self._dataset is EnsembleDataset:
+            raise TypeError(
+                "EnsembleDataset is not allowed as dataset_reference."
+            )
+        allowed_types = (SQLiteDataset, ParquetDataset, Dataset)
+        if not issubclass(self._dataset, allowed_types):
+            raise TypeError(
+                "dataset_reference must be a subclass of SQLiteDataset, "
+                "ParquetDataset, or Dataset."
+            )
@@ -211,7 +218,10 @@ def _validate_dataset_args(self) -> None:
                 )
             except AssertionError:
                 raise ValueError(
-                    f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of selections ({len(self._selection)})."
+                    "The number of dataset paths"
+                    f" ({len(self._dataset_args['path'])})"
+                    " does not match the number of"
+                    f" selections ({len(self._selection)})."
                 )
 
         if self._test_selection is not None:
@@ -223,7 +233,13 @@ def _validate_dataset_args(self) -> None:
                 )
             except AssertionError:
                 raise ValueError(
-                    f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of test selections ({len(self._test_selection)}). If you'd like to test on only a subset of the {len(self._dataset_args['path'])} datasets, please provide empty test selections for the others."
+                    "The number of dataset paths"
+                    f" ({len(self._dataset_args['path'])}) does not match "
+                    "the number of test selections "
+                    f"({len(self._test_selection)}). If you'd like to test "
+                    "on only a subset of the "
+                    f"{len(self._dataset_args['path'])} datasets, "
+                    "please provide empty test selections for the others."
                 )
 
     def _validate_dataloader_args(self) -> None:
@@ -244,7 +260,9 @@ def _validate_dataloader_args(self) -> None:
     def _resolve_selections(self) -> None:
         if self._test_selection is None:
             self.warning_once(
-                f"{self.__class__.__name__} did not receive an argument for `test_selection` and will therefore not have a prediction dataloader available."
+                f"{self.__class__.__name__} did not receive an"
+                " argument for `test_selection` and will "
+                "therefore not have a prediction dataloader available."
            )
         if self._selection is not None:
             # Split the selection into train/validation
@@ -270,9 +288,14 @@ def _resolve_selections(self) -> None:
             )
 
         else:  # selection is None
-            # If not provided, we infer it by grabbing all event ids in the dataset.
+            # If not provided, we infer it by grabbing
+            # all event ids in the dataset.
             self.info(
-                f"{self.__class__.__name__} did not receive an argument for `selection`. Selection will automatically be created with a split of train: {self._train_val_split[0]} and validation: {self._train_val_split[1]}"
+                f"{self.__class__.__name__} did not receive an"
+                " argument for `selection`. Selection "
+                "will automatically be created with a split of "
+                f"train: {self._train_val_split[0]} and "
+                f"validation: {self._train_val_split[1]}"
             )
             (
                 self._train_selection,
@@ -372,8 +395,6 @@ def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset:
         Return:
             Dataset object constructed from input arguments.
         """
-        print(tmp_args, "temp argument")
-        print(self._dataset, "<-dataset")
         dataset = self._dataset(**tmp_args)  # type: ignore
         return dataset
 
@@ -390,7 +411,7 @@ def _create_dataset(
         """
         if self._use_ensemble_dataset:
             # Construct multiple datasets and pass to EnsembleDataset
-            # At this point, we have checked that len(selection) == len(dataset_args['path'])
+            # At this point, len(selection) == len(dataset_args['path']).
             datasets = []
             for dataset_idx in range(len(selection)):
                 datasets.append(
@@ -405,7 +426,8 @@ def _create_dataset(
             )
         else:
             # Construct single dataset
             dataset = self._create_single_dataset(
-                selection=selection, path=self._dataset_args["path"]  # type: ignore
+                selection=selection,
+                path=self._dataset_args["path"],  # type: ignore
             )
         return dataset
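-- 
A minimal usage sketch of the refactored module follows, to show the intended
call pattern; it is not part of the patch. The database path, pulsemap and
feature names, and dataloader kwargs are hypothetical, and the dataloader
accessors are assumed to follow the standard LightningDataModule interface.

    from graphnet.data.datamodule import GraphNeTDataModule
    from graphnet.data.dataset import SQLiteDataset

    # `dataset_reference` takes the dataset class itself (non-instantiated),
    # which `_validate_dataset_class()` above checks before use.
    dm = GraphNeTDataModule(
        dataset_reference=SQLiteDataset,
        dataset_args={
            "path": "/data/train_events.db",  # hypothetical database path
            "pulsemaps": "pulse_table",  # hypothetical pulsemap table name
            "features": ["dom_x", "dom_y", "dom_z", "dom_time"],
            "truth": ["energy", "zenith"],
        },
        train_dataloader_kwargs={"batch_size": 128, "num_workers": 4},
        # `selection`/`test_selection` omitted: a train/val split of
        # [0.9, 0.10] is then created automatically, seeded by `split_seed`,
        # and no prediction dataloader is available (a warning is logged).
    )

    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()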