Refactored GraphNeTDataModule class
Signed-off-by: samadpls <[email protected]>
samadpls committed Feb 12, 2024
1 parent 57f49fe commit 9f66587
Showing 1 changed file with 49 additions and 27 deletions.
76 changes: 49 additions & 27 deletions src/graphnet/data/datamodule.py
@@ -34,14 +34,22 @@ def __init__(
Args:
dataset_reference: A non-instantiated reference to the dataset class.
- dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with.
- selection: (Optional) a list of event id's used for training and validation, Default None.
- test_selection: (Optional) a list of event id's used for testing, Default None.
- train_dataloader_kwargs: Arguments for the training DataLoader, Default None.
- validation_dataloader_kwargs: Arguments for the validation DataLoader, Default None.
- test_dataloader_kwargs: Arguments for the test DataLoader, Default None.
- train_val_split (Optional): Split ratio for training and validation sets. Default is [0.9, 0.10].
- split_seed: seed used for shuffling and splitting selections into train/validation, Default 42.
+ dataset_args: Arguments to instantiate
+     graphnet.data.dataset.Dataset with.
+ selection: (Optional) a list of event id's used for training
+     and validation, Default None.
+ test_selection: (Optional) a list of event id's used for testing,
+     Default None.
+ train_dataloader_kwargs: Arguments for the training DataLoader,
+     Default None.
+ validation_dataloader_kwargs: Arguments for the validation
+     DataLoader, Default None.
+ test_dataloader_kwargs: Arguments for the test DataLoader,
+     Default None.
+ train_val_split (Optional): Split ratio for training and
+     validation sets. Default is [0.9, 0.10].
+ split_seed: seed used for shuffling and splitting selections into
+     train/validation, Default 42.
"""
Logger.__init__(self)
self._make_sure_root_logger_is_configured()
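For context, a minimal usage sketch of the constructor documented above. The database path and dataloader settings are hypothetical, and `dataset_args` is trimmed to the key this diff's validation logic inspects; a real `SQLiteDataset` needs further arguments (pulse maps, features, truth, etc.):

```python
from graphnet.data.datamodule import GraphNeTDataModule
from graphnet.data.dataset import SQLiteDataset

# Hypothetical arguments; adjust the path and kwargs to your own data.
dm = GraphNeTDataModule(
    dataset_reference=SQLiteDataset,           # non-instantiated class, per the docstring
    dataset_args={"path": "/data/events.db"},  # forwarded as Dataset(**dataset_args)
    selection=None,        # None -> selection is inferred from all event ids
    test_selection=None,   # None -> no prediction dataloader is created
    train_dataloader_kwargs={"batch_size": 16, "num_workers": 4},
    train_val_split=[0.9, 0.10],
    split_seed=42,
)
```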
@@ -158,7 +166,8 @@ def _create_dataloader(
"""Create a DataLoader for the given dataset.
Args:
- dataset (Union[Dataset, EnsembleDataset]): The dataset to create a DataLoader for.
+ dataset (Union[Dataset, EnsembleDataset]):
+     The dataset to create a DataLoader for.
Returns:
DataLoader: The DataLoader configured for the given dataset.
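The body of this helper is collapsed in the diff; the general pattern such a method follows is to unpack the stored kwargs into a DataLoader. A sketch under that assumption, using torch's `DataLoader` for illustration (graphnet ships its own DataLoader wrapper with a graph-aware collate function):

```python
from typing import Any, Dict, Optional

from torch.utils.data import DataLoader, Dataset


def create_dataloader_sketch(
    dataset: Dataset, dataloader_kwargs: Optional[Dict[str, Any]]
) -> DataLoader:
    """Illustrative only: wrap a dataset with stored DataLoader arguments."""
    if dataloader_kwargs is None:
        raise ValueError("No dataloader arguments were provided.")
    return DataLoader(dataset, **dataloader_kwargs)
```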
@@ -186,15 +195,13 @@ def _validate_dataset_class(self) -> None:
ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset
type is detected, or if an EnsembleDataset is used.
"""
- print(self._dataset, "Dataset\n")
- print(
-     f"Type of self._dataset before validation check: {type(self._dataset)}"
- )
- # if type(self._dataset) not in [SQLiteDataset, ParquetDataset, Dataset]:
- #     raise TypeError(
- #         "dataset_reference must be an instance of SQLiteDataset, ParquetDataset, or Dataset."
- #     )
- if isinstance(self._dataset, EnsembleDataset):
+ allowed_types = (SQLiteDataset, ParquetDataset, Dataset)
+ if self._dataset not in allowed_types:
+     raise TypeError(
+         "dataset_reference must be an instance "
+         "of SQLiteDataset, ParquetDataset, or Dataset."
+     )
+ if self._dataset is EnsembleDataset:
raise TypeError(
"EnsembleDataset is not allowed as dataset_reference."
)
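One subtlety in the new check: `self._dataset` holds a class reference, not an instance, so `self._dataset not in allowed_types` only accepts those three exact classes and would reject a user-defined `Dataset` subclass. A hypothetical subclass-friendly variant (not what this commit does), assuming the same imports as `datamodule.py`:

```python
from graphnet.data.dataset import Dataset, EnsembleDataset


def validate_dataset_class_sketch(dataset_reference: type) -> None:
    """Illustrative variant: accept any Dataset subclass except EnsembleDataset."""
    # Guard against being handed an instance instead of a class.
    if not (isinstance(dataset_reference, type)
            and issubclass(dataset_reference, Dataset)):
        raise TypeError(
            "dataset_reference must be a Dataset class, "
            "e.g. SQLiteDataset or ParquetDataset."
        )
    if dataset_reference is EnsembleDataset:
        raise TypeError("EnsembleDataset is not allowed as dataset_reference.")
```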
@@ -211,7 +218,10 @@ def _validate_dataset_args(self) -> None:
)
except AssertionError:
raise ValueError(
f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of selections ({len(self._selection)})."
"The number of dataset paths"
f" ({len(self._dataset_args['path'])})"
" does not match the number of"
f" selections ({len(self._selection)})."
)

if self._test_selection is not None:
@@ -223,7 +233,13 @@ def _validate_dataset_args(self) -> None:
)
except AssertionError:
raise ValueError(
f"The number of dataset paths ({len(self._dataset_args['path'])}) does not match the number of test selections ({len(self._test_selection)}). If you'd like to test on only a subset of the {len(self._dataset_args['path'])} datasets, please provide empty test selections for the others."
"The number of dataset paths "
f" ({len(self._dataset_args['path'])}) does not match "
"the number of test selections "
f"({len(self._test_selection)}).If you'd like to test "
"on only a subset of the "
f"{len(self._dataset_args['path'])} datasets, "
"please provide empty test selections for the others."
)

def _validate_dataloader_args(self) -> None:
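The two length checks above enforce a one-to-one pairing between dataset paths and per-dataset selections. A hypothetical configuration that satisfies both, using the documented trick of an empty test selection for datasets that should be skipped at test time (all paths and event ids below are made up):

```python
# One selection list per path: len(selection) == len(dataset_args["path"]).
dataset_args = {"path": ["/data/a.db", "/data/b.db"]}

selection = [
    [0, 1, 2, 5, 8],  # event ids used for train/val in a.db
    [3, 4, 7],        # event ids used for train/val in b.db
]

# Test only on the first dataset; the second gets an empty selection.
test_selection = [
    [6, 9],
    [],
]
```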
@@ -244,7 +260,9 @@ def _validate_dataloader_args(self) -> None:
def _resolve_selections(self) -> None:
if self._test_selection is None:
self.warning_once(
f"{self.__class__.__name__} did not receive an argument for `test_selection` and will therefore not have a prediction dataloader available."
f"{self.__class__.__name__} did not receive an"
" argument for `test_selection` and will "
"therefore not have a prediction dataloader available."
)
if self._selection is not None:
# Split the selection into train/validation
@@ -270,9 +288,14 @@ def _resolve_selections(self) -> None:
)

else: # selection is None
- # If not provided, we infer it by grabbing all event ids in the dataset.
+ # If not provided, we infer it by grabbing
+ # all event ids in the dataset.
self.info(
f"{self.__class__.__name__} did not receive an argument for `selection`. Selection will automatically be created with a split of train: {self._train_val_split[0]} and validation: {self._train_val_split[1]}"
f"{self.__class__.__name__} did not receive an"
" for `selection`. Selection will "
"will automatically be created with a split of "
f"train: {self._train_val_split[0]} and "
f"validation: {self._train_val_split[1]}"
)
(
self._train_selection,
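The split logic itself is collapsed out of this diff. A minimal sketch of a seeded shuffle-and-split consistent with the `split_seed` and `train_val_split` arguments, illustrative rather than the file's exact implementation:

```python
from typing import List, Sequence, Tuple

import numpy as np


def split_selection_sketch(
    selection: Sequence[int],
    train_val_split: Sequence[float] = (0.9, 0.10),
    split_seed: int = 42,
) -> Tuple[List[int], List[int]]:
    """Illustrative: shuffle event ids with a fixed seed, then split."""
    rng = np.random.default_rng(split_seed)
    shuffled = rng.permutation(np.asarray(selection))
    n_train = int(len(shuffled) * train_val_split[0])
    return shuffled[:n_train].tolist(), shuffled[n_train:].tolist()


train_ids, val_ids = split_selection_sketch(list(range(100)))
```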
@@ -372,8 +395,6 @@ def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset:
Return:
Dataset object constructed from input arguments.
"""
- print(tmp_args, "temp argument")
- print(self._dataset, "<-dataset")
dataset = self._dataset(**tmp_args) # type: ignore
return dataset

@@ -390,7 +411,7 @@ def _create_dataset(
"""
if self._use_ensemble_dataset:
# Construct multiple datasets and pass to EnsembleDataset
- # At this point, we have checked that len(selection) == len(dataset_args['path'])
+ # len(selection) == len(dataset_args['path'])
datasets = []
for dataset_idx in range(len(selection)):
datasets.append(
@@ -405,7 +426,8 @@
else:
# Construct single dataset
dataset = self._create_single_dataset(
-     selection=selection, path=self._dataset_args["path"]  # type: ignore
+     selection=selection,
+     path=self._dataset_args["path"],  # type: ignore
)
return dataset
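To summarize the branch above: with multiple paths, each (path, selection) pair becomes its own dataset and the results are wrapped in an `EnsembleDataset`; with a single path, one dataset is built directly. An illustrative outline (the function name and exact kwarg plumbing are assumptions, not the file's body):

```python
from typing import Any, Dict, Sequence

from graphnet.data.dataset import EnsembleDataset


def create_dataset_sketch(
    dataset_cls: type,
    dataset_args: Dict[str, Any],
    selection: Sequence[Any],
    use_ensemble: bool,
):
    """Illustrative: one dataset per (path, selection) pair, else one dataset."""
    if use_ensemble:
        # Lengths were validated earlier:
        # len(selection) == len(dataset_args["path"]).
        datasets = [
            dataset_cls(**{**dataset_args, "path": path, "selection": sel})
            for path, sel in zip(dataset_args["path"], selection)
        ]
        return EnsembleDataset(datasets)
    return dataset_cls(**{**dataset_args, "selection": selection})
```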
