Skip to content

Commit

Permalink
Refactored dataset_reference argument in GraphNeTDataModule
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Feb 12, 2024
1 parent 9f66587 commit 33fc7d0
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/graphnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(
"""Create dataloaders from dataset.
Args:
dataset_reference: A non-instantiated reference to the dataset class.
dataset_reference: A non-instantiated reference
to the dataset class.
dataset_args: Arguments to instantiate
graphnet.data.dataset.Dataset with.
selection: (Optional) a list of event id's used for training
Expand Down Expand Up @@ -96,7 +97,9 @@ def setup(self, stage: str) -> None:
self._test_selection is not None
or len(self._test_dataloader_kwargs) > 0
):
self._test_dataset = self._create_dataset(self._test_selection) # type: ignore
self._test_dataset = self._create_dataset(
self._test_selection # type: ignore
)
if stage == "fit" or stage == "validate":
if self._train_selection is not None:
self._train_dataset = self._create_dataset(
Expand Down Expand Up @@ -318,7 +321,9 @@ def _split_selection(
flat_selection = [selection]
elif isinstance(selection[0], list):
flat_selection = [
item for sublist in selection for item in sublist # type: ignore
item
for sublist in selection
for item in sublist # type: ignore
]
else:
flat_selection = selection # type: ignore
Expand Down

0 comments on commit 33fc7d0

Please sign in to comment.