From 4012e77b188e2af81de83e2d792bd84d1b20a1d6 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Sat, 14 Sep 2024 16:27:44 +0200 Subject: [PATCH] Only infer train/val selection in DataModule if test selection is not given --- src/graphnet/data/curated_datamodule.py | 6 +++--- src/graphnet/data/datamodule.py | 20 +++++++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/graphnet/data/curated_datamodule.py b/src/graphnet/data/curated_datamodule.py index 942d91bfa..13566a3fb 100644 --- a/src/graphnet/data/curated_datamodule.py +++ b/src/graphnet/data/curated_datamodule.py @@ -321,7 +321,7 @@ def __init__( If instantiated in "test" or "test-no-noise" mode, already processed photons will be read from "pulses" or - "pulses-no-noise", respectively. GraphDefinition passed to the dataset + "pulses_no_noise", respectively. GraphDefinition passed to the dataset should in this case not smear charge and time variables, and should not apply any merging. @@ -351,7 +351,7 @@ def __init__( elif self._mode == "test": self._pulsemaps = ["pulses"] elif self._mode == "test-no-noise": - self._pulsemaps = ["pulses-no-noise"] + self._pulsemaps = ["pulses_no_noise"] else: raise AssertionError( "'mode' must be one of " @@ -508,7 +508,7 @@ def __init__( If instantiated in "test" or "test-no-noise" mode, already processed photons will be read from "pulses" or - "pulses-no-noise", respectively. GraphDefinition passed to the dataset + "pulses_no_noise", respectively. GraphDefinition passed to the dataset should in this case not smear charge and time variables, and should not apply any merging. diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 33f31c5fe..b0609eb06 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -192,8 +192,13 @@ def setup(self, stage: str) -> None: self._train_dataset = self._create_dataset( self._train_selection ) + else: + self._train_dataset = None + if self._val_selection is not None: self._val_dataset = self._create_dataset(self._val_selection) + else: + self._val_dataset = None return @@ -377,12 +382,12 @@ def _resolve_selections(self) -> None: self._selection ) - else: # selection is None + elif self._test_selection is None: # selection is None # If not provided, we infer it by grabbing # all event ids in the dataset. self.info( - f"{self.__class__.__name__} did not receive an" - " for `selection`. Selection will " + f"{self.__class__.__name__} did not receive an argument" + " for `selection`. Selection " "will automatically be created with a split of " f"train: {self._train_val_split[0]} and " f"validation: {self._train_val_split[1]}" @@ -391,6 +396,15 @@ def _resolve_selections(self) -> None: self._train_selection, self._val_selection, ) = self._infer_selections() # type: ignore + else: + # Only test selection given - no training / val selection inferred + self.info( + f"{self.__class__.__name__} only recieved arguments for a" + " test selection. DataLoaders for training and validation" + " will not be available." + ) + self._train_selection = None # type: ignore + self._val_selection = None # type: ignore def _split_selection( self, selection: Union[int, List[int], List[List[int]]]