diff --git a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py index 755d96b82..b951e6916 100644 --- a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np +from numpy.random import default_rng, Generator import torch from torch_geometric.data import Data @@ -36,6 +37,7 @@ def __init__( loss_weight_table: Optional[str] = None, loss_weight_column: Optional[str] = None, loss_weight_default_value: Optional[float] = None, + seed: Optional[Union[int, Generator]] = None, ): """Construct SQLiteDatasetPerturbed. @@ -78,6 +80,7 @@ def __init__( in this case to events with no value in the corresponding table/column. That is, if no per-event loss weight table/column is provided, this value is ignored. Defaults to None. + seed: Optional seed for random number generation. Defaults to None. """ # Base class constructor super().__init__( @@ -108,6 +111,18 @@ def __init__( self._features.index(key) for key in self._perturbation_dict.keys() ] + if seed is not None: + if isinstance(seed, int): + self.rng = default_rng(seed) + elif isinstance(seed, Generator): + self.rng = seed + else: + raise ValueError( + "Invalid seed. Must be an int or a numpy Generator." + ) + else: + self.rng = default_rng() + def __getitem__(self, sequential_index: int) -> Data: """Return graph `Data` object at `index`.""" if not (0 <= sequential_index < len(self)): @@ -127,7 +142,7 @@ def _perturb_features( self, features: List[Tuple[float, ...]] ) -> List[Tuple[float, ...]]: features_array = np.array(features) - perturbed_features = np.random.normal( + perturbed_features = self.rng.normal( loc=features_array[:, self._perturbation_cols], scale=np.array( list(self._perturbation_dict.values()), dtype=np.float diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 2578ff9a6..f7d5249f9 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -115,8 +115,7 @@ def make_train_validation_dataloader( ) -> Tuple[DataLoader, DataLoader]: """Construct train and test `DataLoader` instances.""" # Reproducibility - rng = np.random.RandomState(seed=seed) - + rng = np.random.default_rng(seed=seed) # Checks(s) if isinstance(pulsemaps, str): pulsemaps = [pulsemaps] @@ -159,13 +158,13 @@ def make_train_validation_dataloader( frac=1, replace=False, random_state=rng ) training_df, validation_df = train_test_split( - shuffled_df, test_size=test_size, random_state=rng + shuffled_df, test_size=test_size, random_state=seed ) training_selection = training_df.values.tolist() validation_selection = validation_df.values.tolist() else: training_selection, validation_selection = train_test_split( - selection, test_size=test_size, random_state=rng + selection, test_size=test_size, random_state=seed ) # Create DataLoaders