diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py new file mode 100644 index 000000000..e629ce4a0 --- /dev/null +++ b/src/graphnet/data/datamodule.py @@ -0,0 +1,456 @@ +"""Base `Dataloader` class(es) used in `graphnet`.""" +from typing import Dict, Any, Optional, List, Tuple, Union +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from copy import deepcopy +from sklearn.model_selection import train_test_split +import pandas as pd + +from graphnet.data.dataset import ( + Dataset, + EnsembleDataset, + SQLiteDataset, + ParquetDataset, +) +from graphnet.utilities.logging import Logger + + +class GraphNeTDataModule(pl.LightningDataModule, Logger): + """General Class for DataLoader Construction.""" + + def __init__( + self, + dataset_reference: Union[SQLiteDataset, ParquetDataset, Dataset], + dataset_args: Dict[str, Any], + selection: Optional[Union[List[int], List[List[int]]]] = None, + test_selection: Optional[Union[List[int], List[List[int]]]] = None, + train_dataloader_kwargs: Optional[Dict[str, Any]] = None, + validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, + test_dataloader_kwargs: Optional[Dict[str, Any]] = None, + train_val_split: Optional[List[float]] = [0.9, 0.10], + split_seed: int = 42, + ) -> None: + """Create dataloaders from dataset. + + Args: + dataset_reference: A non-instantiated reference + to the dataset class. + dataset_args: Arguments to instantiate + graphnet.data.dataset.Dataset with. + selection: (Optional) a list of event id's used for training + and validation, Default None. + test_selection: (Optional) a list of event id's used for testing, + Default None. + train_dataloader_kwargs: Arguments for the training DataLoader, + Default None. + validation_dataloader_kwargs: Arguments for the validation + DataLoader, Default None. + test_dataloader_kwargs: Arguments for the test DataLoader, + Default None. + train_val_split (Optional): Split ratio for training and + validation sets. Default is [0.9, 0.10]. + split_seed: seed used for shuffling and splitting selections into + train/validation, Default 42. + """ + Logger.__init__(self) + self._make_sure_root_logger_is_configured() + self._dataset = dataset_reference + self._dataset_args = dataset_args + self._selection = selection + self._test_selection = test_selection + self._train_val_split = train_val_split or [0.0] + self._rng = split_seed + + self._train_dataloader_kwargs = train_dataloader_kwargs or {} + self._validation_dataloader_kwargs = validation_dataloader_kwargs or {} + self._test_dataloader_kwargs = test_dataloader_kwargs or {} + + # If multiple dataset paths are given, we should use EnsembleDataset + self._use_ensemble_dataset = isinstance( + self._dataset_args["path"], list + ) + + self.setup("fit") + + def prepare_data(self) -> None: + """Prepare the dataset for training.""" + # Download method for curated datasets. Method for download is + # likely dataset-specific, so we can leave it as-is + pass + + def setup(self, stage: str) -> None: + """Prepare Datasets for DataLoaders. + + Args: + stage: lightning stage. Either "fit, validate, test, predict" + """ + # Sanity Checks + self._validate_dataset_class() + self._validate_dataset_args() + self._validate_dataloader_args() + + # Case-handling of selection arguments + self._resolve_selections() + + # Creation of Datasets + if ( + self._test_selection is not None + or len(self._test_dataloader_kwargs) > 0 + ): + self._test_dataset = self._create_dataset( + self._test_selection # type: ignore + ) + if stage == "fit" or stage == "validate": + if self._train_selection is not None: + self._train_dataset = self._create_dataset( + self._train_selection + ) + if self._val_selection is not None: + self._val_dataset = self._create_dataset(self._val_selection) + + return + + @property + def train_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the training DataLoader. + + Returns: + DataLoader: The DataLoader configured for training. + """ + return self._create_dataloader(self._train_dataset) + + @property + def val_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the validation DataLoader. + + Returns: + DataLoader: The DataLoader configured for validation. + """ + return self._create_dataloader(self._val_dataset) + + @property + def test_dataloader(self) -> DataLoader: # type: ignore[override] + """Prepare and return the test DataLoader. + + Returns: + DataLoader: The DataLoader configured for testing. + """ + return self._create_dataloader(self._test_dataset) + + def teardown(self) -> None: # type: ignore[override] + """Perform any necessary cleanup or shutdown procedures. + + This method can be used for tasks such as closing SQLite connections + after training. Override this method as needed. + + Returns: + None + """ + if hasattr(self, "_train_dataset") and isinstance( + self._train_dataset, SQLiteDataset + ): + self._train_dataset._close_connection() + + if hasattr(self, "_val_dataset") and isinstance( + self._val_dataset, SQLiteDataset + ): + self._val_dataset._close_connection() + + if hasattr(self, "_test_dataset") and isinstance( + self._test_dataset, SQLiteDataset + ): + self._test_dataset._close_connection() + + return + + def _create_dataloader( + self, dataset: Union[Dataset, EnsembleDataset] + ) -> DataLoader: + """Create a DataLoader for the given dataset. + + Args: + dataset (Union[Dataset, EnsembleDataset]): + The dataset to create a DataLoader for. + + Returns: + DataLoader: The DataLoader configured for the given dataset. + """ + if dataset == self._train_dataset: + dataloader_args = self._train_dataloader_kwargs + elif dataset == self._val_dataset: + dataloader_args = self._validation_dataloader_kwargs + elif dataset == self._test_dataset: + dataloader_args = self._test_dataloader_kwargs + else: + raise ValueError( + "Unknown dataset encountered during dataloader creation." + ) + + if dataloader_args is None: + raise AttributeError("Dataloader arguments not provided.") + + return DataLoader(dataset=dataset, **dataloader_args) + + def _validate_dataset_class(self) -> None: + """Sanity checks on the dataset reference (self._dataset). + + Checks whether the dataset is an instance of SQLiteDataset, + ParquetDataset, or Dataset. Raises a TypeError if an invalid dataset + type is detected, or if an EnsembleDataset is used. + """ + allowed_types = (SQLiteDataset, ParquetDataset, Dataset) + if self._dataset not in allowed_types: + raise TypeError( + "dataset_reference must be an instance " + "of SQLiteDataset, ParquetDataset, or Dataset." + ) + if self._dataset is EnsembleDataset: + raise TypeError( + "EnsembleDataset is not allowed as dataset_reference." + ) + + def _validate_dataset_args(self) -> None: + """Sanity checks on the arguments for the dataset reference.""" + if isinstance(self._dataset_args["path"], list): + if self._selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of selections given as arg. + assert len(self._dataset_args["path"]) == len( + self._selection + ) + except AssertionError: + raise ValueError( + "The number of dataset paths" + f" ({len(self._dataset_args['path'])})" + " does not match the number of" + f" selections ({len(self._selection)})." + ) + + if self._test_selection is not None: + try: + # Check that the number of dataset paths is equal to the + # number of test selections. + assert len(self._dataset_args["path"]) == len( + self._test_selection + ) + except AssertionError: + raise ValueError( + "The number of dataset paths " + f" ({len(self._dataset_args['path'])}) does not match " + "the number of test selections " + f"({len(self._test_selection)}).If you'd like to test " + "on only a subset of the " + f"{len(self._dataset_args['path'])} datasets, " + "please provide empty test selections for the others." + ) + + def _validate_dataloader_args(self) -> None: + """Sanity check on `dataloader_args`.""" + if "dataset" in self._train_dataloader_kwargs: + raise ValueError( + "`train_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._validation_dataloader_kwargs: + raise ValueError( + "`validation_dataloader_kwargs` must not contain `dataset`" + ) + if "dataset" in self._test_dataloader_kwargs: + raise ValueError( + "`test_dataloader_kwargs` must not contain `dataset`" + ) + + def _resolve_selections(self) -> None: + if self._test_selection is None: + self.warning_once( + f"{self.__class__.__name__} did not receive an" + " argument for `test_selection` and will " + "therefore not have a prediction dataloader available." + ) + if self._selection is not None: + # Split the selection into train/validation + if self._use_ensemble_dataset: + # Split every selection + self._train_selection = [] + self._val_selection = [] + for selection in self._selection: + train_selection, val_selection = self._split_selection( + selection + ) + self._train_selection.append(train_selection) + self._val_selection.append(val_selection) + + else: + # Split the only selection we got + assert isinstance(self._selection, list) + ( + self._train_selection, + self._val_selection, + ) = self._split_selection( # type: ignore + self._selection + ) + + else: # selection is None + # If not provided, we infer it by grabbing + # all event ids in the dataset. + self.info( + f"{self.__class__.__name__} did not receive an" + " for `selection`. Selection will " + "will automatically be created with a split of " + f"train: {self._train_val_split[0]} and " + f"validation: {self._train_val_split[1]}" + ) + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections() # type: ignore + + def _split_selection( + self, selection: Union[int, List[int], List[List[int]]] + ) -> Tuple[List[int], List[int]]: + """Split train selection into train/validation. + + Args: + selection: Training selection to be split + + Returns: + Training selection, Validation selection. + """ + assert isinstance(selection, (int, list)) + if isinstance(selection, int): + flat_selection = [selection] + elif isinstance(selection[0], list): + flat_selection = [ + item + for sublist in selection + for item in sublist # type: ignore + ] + else: + flat_selection = selection # type: ignore + assert isinstance(flat_selection, list) + + train_selection, val_selection = train_test_split( + flat_selection, + train_size=self._train_val_split[0], + test_size=self._train_val_split[1], + random_state=self._rng, + ) + return train_selection, val_selection + + def _infer_selections(self) -> Tuple[List[int], List[int]]: + """Automatically infer training and validation selections. + + Returns: + Training selection, Validation selection + """ + if self._use_ensemble_dataset: + # We must iterate through the dataset paths and infer a train/val + # selection for each. + self._train_selection = [] + self._val_selection = [] + for dataset_path in self._dataset_args["path"]: + ( + train_selection, + val_selection, + ) = self._infer_selections_on_single_dataset(dataset_path) + self._train_selection.append(train_selection) # type: ignore + self._val_selection.append(val_selection) # type: ignore + else: + # Infer selection on a single dataset + ( + self._train_selection, + self._val_selection, + ) = self._infer_selections_on_single_dataset( # type: ignore + self._dataset_args["path"] + ) + + return (self._train_selection, self._val_selection) # type: ignore + + def _infer_selections_on_single_dataset( + self, dataset_path: str + ) -> Tuple[List[int], List[int]]: + """Automatically infers dataset train/val selections. + + Args: + dataset_path (str): The path to the dataset. + + Returns: + Tuple[List[int], List[int]]: Training and validation selections. + """ + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = dataset_path + tmp_dataset = self._construct_dataset(tmp_args) + + all_events = ( + tmp_dataset._get_all_indices() + ) # unshuffled list, sequential index + + # Multiple lines to avoid one large + all_events = ( + pd.DataFrame(all_events) + .sample(frac=1, replace=False, random_state=self._rng) + .values.tolist() + ) # shuffled list + + return self._split_selection(all_events) + + def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset: + """Construct dataset. + + Return: + Dataset object constructed from input arguments. + """ + dataset = self._dataset(**tmp_args) # type: ignore + return dataset + + def _create_dataset( + self, selection: Union[List[int], List[List[int]], List[float]] + ) -> Union[EnsembleDataset, Dataset]: + """Instantiate `dataset_reference`. + + Args: + selection: The selected event id's. + + Returns: + A dataset, either an instance of `EnsembleDataset` or `Dataset`. + """ + if self._use_ensemble_dataset: + # Construct multiple datasets and pass to EnsembleDataset + # len(selection) == len(dataset_args['path']) + datasets = [] + for dataset_idx in range(len(selection)): + datasets.append( + self._create_single_dataset( + selection=selection[dataset_idx], # type: ignore + path=self._dataset_args["path"][dataset_idx], + ) + ) + + dataset = EnsembleDataset(datasets) + + else: + # Construct single dataset + dataset = self._create_single_dataset( + selection=selection, + path=self._dataset_args["path"], # type:ignore + ) + return dataset + + def _create_single_dataset( + self, + selection: Union[List[int], List[List[int]], List[float]], + path: str, + ) -> Dataset: + """Instantiate a single `Dataset`. + + Args: + selection: A selection for a single dataset. + path: Path to a single dataset + + Returns: + An instance of `Dataset`. + """ + tmp_args = deepcopy(self._dataset_args) + tmp_args["path"] = path + tmp_args["selection"] = selection + return self._construct_dataset(tmp_args) diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index df7c92e15..fca4a21e0 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -317,3 +317,18 @@ def save_results( model.save_state_dict(path + "/" + tag + "_state_dict.pth") model.save(path + "/" + tag + "_model.pth") Logger().info("Results saved at: \n %s" % path) + + +def save_selection(selection: List[int], file_path: str) -> None: + """Save the list of event numbers to a CSV file. + + Args: + selection: List of event ids. + file_path: File path to save the selection. + """ + assert isinstance( + selection, list + ), "Selection should be a list of integers." + with open(file_path, "w") as f: + f.write(",".join(map(str, selection))) + f.write("\n") diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py new file mode 100644 index 000000000..9f8a1b745 --- /dev/null +++ b/tests/data/test_datamodule.py @@ -0,0 +1,334 @@ +"""Unit tests for DataModule.""" + +from copy import deepcopy +import os +from typing import List, Any, Dict, Tuple +import pandas as pd +import sqlite3 +import pytest +from torch.utils.data import SequentialSampler + +from graphnet.constants import EXAMPLE_DATA_DIR +from graphnet.data.constants import FEATURES, TRUTH +from graphnet.data.dataset import SQLiteDataset, ParquetDataset +from graphnet.data.datamodule import GraphNeTDataModule +from graphnet.models.detector import IceCubeDeepCore +from graphnet.models.graphs import KNNGraph +from graphnet.models.graphs.nodes import NodesAsPulses +from graphnet.training.utils import save_selection + + +def extract_all_events_ids( + file_path: str, dataset_kwargs: Dict[str, Any] +) -> List[int]: + """Extract all available event ids.""" + if file_path.endswith(".parquet"): + selection = pd.read_parquet(file_path)["event_id"].to_numpy().tolist() + elif file_path.endswith(".db"): + with sqlite3.connect(file_path) as conn: + query = f'SELECT event_no FROM {dataset_kwargs["truth_table"]}' + selection = ( + pd.read_sql(query, conn)["event_no"].to_numpy().tolist() + ) + else: + raise AssertionError( + f"File extension not accepted: {file_path.split('.')[-1]}" + ) + return selection + + +@pytest.fixture +def dataset_ref(request: pytest.FixtureRequest) -> pytest.FixtureRequest: + """Return the dataset reference.""" + return request.param + + +@pytest.fixture +def dataset_setup(dataset_ref: pytest.FixtureRequest) -> tuple: + """Set up the dataset for testing. + + Args: + dataset_ref: The dataset reference. + + Returns: + A tuple with the dataset reference, dataset kwargs, and dataloader kwargs. + """ + # Grab public dataset paths + data_path = ( + f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db" + if dataset_ref is SQLiteDataset + else f"{EXAMPLE_DATA_DIR}/parquet/prometheus/prometheus-events.parquet" + ) + + # Setup basic inputs; can be altered by individual tests + graph_definition = KNNGraph( + detector=IceCubeDeepCore(), + node_definition=NodesAsPulses(), + nb_nearest_neighbours=8, + input_feature_names=FEATURES.DEEPCORE, + ) + + dataset_kwargs = { + "truth_table": "mc_truth", + "pulsemaps": "total", + "truth": TRUTH.PROMETHEUS, + "features": FEATURES.PROMETHEUS, + "path": data_path, + "graph_definition": graph_definition, + } + + dataloader_kwargs = {"batch_size": 2, "num_workers": 1} + + return dataset_ref, dataset_kwargs, dataloader_kwargs + + +@pytest.fixture +def selection() -> List[int]: + """Return a selection.""" + return [1, 2, 3, 4, 5] + + +@pytest.fixture +def file_path(tmpdir: str) -> str: + """Return a file path.""" + return os.path.join(tmpdir, "selection.csv") + + +def test_save_selection(selection: List[int], file_path: str) -> None: + """Test `save_selection` function.""" + save_selection(selection, file_path) + + assert os.path.exists(file_path) + + with open(file_path, "r") as f: + content = f.read() + assert content.strip() == "1,2,3,4,5" + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_single_dataset_without_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Verify GraphNeTDataModule behavior when no test selection is provided. + + Args: + dataset_setup: Tuple with dataset reference, dataset arguments, and dataloader arguments. + + Raises: + Exception: If the test dataloader is accessed without providing a test selection. + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + + # Only training_dataloader args + # Default values should be assigned to validation dataloader + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + ) + + train_dataloader = dm.train_dataloader + val_dataloader = dm.val_dataloader + + with pytest.raises(Exception): + # should fail because we provided no test selection + test_dataloader = dm.test_dataloader # noqa + # validation loader should have shuffle = False by default + assert isinstance(val_dataloader.sampler, SequentialSampler) + # Should have identical batch_size + assert val_dataloader.batch_size != train_dataloader.batch_size + # Training dataloader should contain more batches + assert len(train_dataloader) > len(val_dataloader) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_single_dataset_with_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test that selection functionality of DataModule behaves as expected. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset arguments, and dataloader arguments. + + Returns: + None + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + # extract all events + file_path = dataset_kwargs["path"] + selection = extract_all_events_ids( + file_path=file_path, dataset_kwargs=dataset_kwargs + ) + + test_selection = selection[0:10] + train_val_selection = selection[10:] + + # Only training_dataloader args + # Default values should be assigned to validation dataloader + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=train_val_selection, + test_selection=test_selection, + ) + + train_dataloader = dm.train_dataloader + val_dataloader = dm.val_dataloader + test_dataloader = dm.test_dataloader + + # Check that the training and validation dataloader contains + # the same number of events as was given in the selection. + assert len(train_dataloader.dataset) + len(val_dataloader.dataset) == len(train_val_selection) # type: ignore + # Check that the number of events in the test dataset is equal to the + # number of events given in the selection. + assert len(test_dataloader.dataset) == len(test_selection) # type: ignore + # Training dataloader should have more batches + assert len(train_dataloader) > len(val_dataloader) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_dataloader_args( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test that arguments to dataloaders are propagated correctly. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. + + Returns: + None + """ + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + val_dataloader_kwargs = deepcopy(dataloader_kwargs) + test_dataloader_kwargs = deepcopy(dataloader_kwargs) + + # Setting batch sizes to different values + val_dataloader_kwargs["batch_size"] = 1 + test_dataloader_kwargs["batch_size"] = 2 + dataloader_kwargs["batch_size"] = 3 + + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + validation_dataloader_kwargs=val_dataloader_kwargs, + test_dataloader_kwargs=test_dataloader_kwargs, + ) + + # Check that the resulting dataloaders have the right batch sizes + assert dm.train_dataloader.batch_size == dataloader_kwargs["batch_size"] + assert dm.val_dataloader.batch_size == val_dataloader_kwargs["batch_size"] + assert ( + dm.test_dataloader.batch_size == test_dataloader_kwargs["batch_size"] + ) + + +@pytest.mark.parametrize( + "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True +) +def test_ensemble_dataset_without_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test ensemble dataset functionality without selections. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. + + Returns: + None + """ + # Make dataloaders from single dataset + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + dm_single = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=deepcopy(dataset_kwargs), + train_dataloader_kwargs=dataloader_kwargs, + ) + + # Copy dataset path twice; mimic ensemble dataset behavior + ensemble_dataset_kwargs = deepcopy(dataset_kwargs) + dataset_path = ensemble_dataset_kwargs["path"] + ensemble_dataset_kwargs["path"] = [dataset_path, dataset_path] + + # Create dataloaders from multiple datasets + dm_ensemble = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + ) + + # Test that the ensemble dataloaders contain more batches + assert len(dm_single.train_dataloader) < len(dm_ensemble.train_dataloader) + assert len(dm_single.val_dataloader) < len(dm_ensemble.val_dataloader) + + +@pytest.mark.parametrize("dataset_ref", [SQLiteDataset, ParquetDataset]) +def test_ensemble_dataset_with_selections( + dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]] +) -> None: + """Test ensemble dataset functionality with selections. + + Args: + dataset_setup (Tuple[Any, Dict[str, Any], Dict[str, int]]): A tuple containing the dataset reference, + dataset keyword arguments, and dataloader keyword arguments. + + Returns: + None + """ + # extract all events + dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup + file_path = dataset_kwargs["path"] + selection = extract_all_events_ids( + file_path=file_path, dataset_kwargs=dataset_kwargs + ) + + # Copy dataset path twice; mimic ensemble dataset behavior + ensemble_dataset_kwargs = deepcopy(dataset_kwargs) + dataset_path = ensemble_dataset_kwargs["path"] + ensemble_dataset_kwargs["path"] = [dataset_path, dataset_path] + + # pass two datasets but only one selection; should fail: + with pytest.raises(Exception): + _ = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=selection, + ) + + # Pass two datasets and two selections; should work: + selection_1 = selection[0:20] + selection_2 = selection[0:10] + dm = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=[selection_1, selection_2], + ) + n_events_in_dataloaders = len(dm.train_dataloader.dataset) + len(dm.val_dataloader.dataset) # type: ignore + + # Check that the number of events in train/val match + assert n_events_in_dataloaders == len(selection_1) + len(selection_2) + + # Pass two datasets, two selections and two test selections; should work + dm2 = GraphNeTDataModule( + dataset_reference=dataset_ref, + dataset_args=ensemble_dataset_kwargs, + train_dataloader_kwargs=dataloader_kwargs, + selection=[selection, selection], + test_selection=[selection_1, selection_2], + ) + + # Check that the number of events in test dataloaders are correct. + n_events_in_test_dataloaders = len(dm2.test_dataloader.dataset) # type: ignore + assert n_events_in_test_dataloaders == len(selection_1) + len(selection_2)