From 41fdfa8ff7ae464116d85df16211b069921ec171 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 17 Apr 2024 10:29:10 +0200 Subject: [PATCH] update type hints --- src/graphnet/data/curated_datamodule.py | 4 ++-- src/graphnet/data/datamodule.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/graphnet/data/curated_datamodule.py b/src/graphnet/data/curated_datamodule.py index a206783bc..63b691c9d 100644 --- a/src/graphnet/data/curated_datamodule.py +++ b/src/graphnet/data/curated_datamodule.py @@ -31,8 +31,8 @@ def __init__( features: Optional[List[str]] = None, backend: str = "parquet", train_dataloader_kwargs: Optional[Dict[str, Any]] = None, - validation_dataloader_kwargs: Optional[Dict[str, Any]] = None, - test_dataloader_kwargs: Optional[Dict[str, Any]] = None, + validation_dataloader_kwargs: Dict[str, Any] = None, + test_dataloader_kwargs: Dict[str, Any] = None, ) -> None: """Construct CuratedDataset. diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py index 44cf93fc9..33f31c5fe 100644 --- a/src/graphnet/data/datamodule.py +++ b/src/graphnet/data/datamodule.py @@ -26,10 +26,7 @@ def __init__( dataset_args: Dict[str, Any], selection: Optional[Union[List[int], List[List[int]]]] = None, test_selection: Optional[Union[List[int], List[List[int]]]] = None, - train_dataloader_kwargs: Dict[str, Any] = { - "batch_size": 2, - "num_workers": 1, - }, + train_dataloader_kwargs: Dict[str, Any] = None, validation_dataloader_kwargs: Dict[str, Any] = None, test_dataloader_kwargs: Dict[str, Any] = None, train_val_split: Optional[List[float]] = [0.9, 0.10], @@ -67,6 +64,9 @@ def __init__( self._train_val_split = train_val_split or [0.0] self._rng = split_seed + if train_dataloader_kwargs is None: + train_dataloader_kwargs = {"batch_size": 2, "num_workers": 1} + self._set_dataloader_kwargs( train_dataloader_kwargs, validation_dataloader_kwargs,