Skip to content

Commit

Permalink
Merge pull request #695 from RasmusOrsoe/fix_code_quality
Browse files Browse the repository at this point in the history
Update Type hints in `GraphNeTDataModule`
  • Loading branch information
RasmusOrsoe authored Apr 17, 2024
2 parents 9fc3400 + 41fdfa8 commit 2248da4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/graphnet/data/curated_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def __init__(
features: Optional[List[str]] = None,
backend: str = "parquet",
train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
validation_dataloader_kwargs: Dict[str, Any] = None,
test_dataloader_kwargs: Dict[str, Any] = None,
) -> None:
"""Construct CuratedDataset.
Expand Down
8 changes: 4 additions & 4 deletions src/graphnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ def __init__(
dataset_args: Dict[str, Any],
selection: Optional[Union[List[int], List[List[int]]]] = None,
test_selection: Optional[Union[List[int], List[List[int]]]] = None,
train_dataloader_kwargs: Dict[str, Any] = {
"batch_size": 2,
"num_workers": 1,
},
train_dataloader_kwargs: Dict[str, Any] = None,
validation_dataloader_kwargs: Dict[str, Any] = None,
test_dataloader_kwargs: Dict[str, Any] = None,
train_val_split: Optional[List[float]] = [0.9, 0.10],
Expand Down Expand Up @@ -67,6 +64,9 @@ def __init__(
self._train_val_split = train_val_split or [0.0]
self._rng = split_seed

if train_dataloader_kwargs is None:
train_dataloader_kwargs = {"batch_size": 2, "num_workers": 1}

self._set_dataloader_kwargs(
train_dataloader_kwargs,
validation_dataloader_kwargs,
Expand Down

0 comments on commit 2248da4

Please sign in to comment.