From 7b29e2e540713ce5bac0c644a9341dfe23b97162 Mon Sep 17 00:00:00 2001 From: samadpls Date: Sat, 9 Sep 2023 17:09:29 +0500 Subject: [PATCH 1/6] Update Random Number Generation with NumPy Generators --- src/graphnet/training/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 2578ff9a6..4518bf4f1 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -115,7 +115,7 @@ def make_train_validation_dataloader( ) -> Tuple[DataLoader, DataLoader]: """Construct train and test `DataLoader` instances.""" # Reproducibility - rng = np.random.RandomState(seed=seed) + rng = np.random.default_rng(seed=seed) # Checks(s) if isinstance(pulsemaps, str): From aafde7d39ccc0399af4b6fef181bbe52e13e4c3b Mon Sep 17 00:00:00 2001 From: samadpls Date: Sat, 9 Sep 2023 19:50:27 +0500 Subject: [PATCH 2/6] added seed support in sqlite_dataset_perturbed.py --- .../dataset/sqlite/sqlite_dataset_perturbed.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py index 755d96b82..ce52b46a7 100644 --- a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np +from numpy.random import default_rng, Generator import torch from torch_geometric.data import Data @@ -36,6 +37,7 @@ def __init__( loss_weight_table: Optional[str] = None, loss_weight_column: Optional[str] = None, loss_weight_default_value: Optional[float] = None, + seed: Optional[Union[int, Generator]] = None, ): """Construct SQLiteDatasetPerturbed. @@ -78,6 +80,10 @@ def __init__( in this case to events with no value in the corresponding table/column. That is, if no per-event loss weight table/column is provided, this value is ignored. Defaults to None. + seed: Optional seed for random number generation (int or numpy Generator). + If provided, it will be used to initialize the random number generator + for data perturbation. Defaults to None. + """ # Base class constructor super().__init__( @@ -108,6 +114,16 @@ def __init__( self._features.index(key) for key in self._perturbation_dict.keys() ] + if seed is not None: + if isinstance(seed, int): + self.rng = default_rng(seed) + elif isinstance(seed, Generator): + self.rng = seed + else: + raise ValueError("Invalid seed. Must be an int or a numpy Generator.") + else: + self.rng = default_rng() + def __getitem__(self, sequential_index: int) -> Data: """Return graph `Data` object at `index`.""" if not (0 <= sequential_index < len(self)): @@ -127,7 +143,7 @@ def _perturb_features( self, features: List[Tuple[float, ...]] ) -> List[Tuple[float, ...]]: features_array = np.array(features) - perturbed_features = np.random.normal( + perturbed_features = self.rng.normal( loc=features_array[:, self._perturbation_cols], scale=np.array( list(self._perturbation_dict.values()), dtype=np.float From ddd0e187d8f7e77df6e72bb11a2e89818b38d024 Mon Sep 17 00:00:00 2001 From: samadpls Date: Sat, 9 Sep 2023 22:33:13 +0500 Subject: [PATCH 3/6] refactored the code --- .../data/dataset/sqlite/sqlite_dataset_perturbed.py | 11 +++++------ src/graphnet/training/utils.py | 7 +++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py index ce52b46a7..b951e6916 100644 --- a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py @@ -37,7 +37,7 @@ def __init__( loss_weight_table: Optional[str] = None, loss_weight_column: Optional[str] = None, loss_weight_default_value: Optional[float] = None, - seed: Optional[Union[int, Generator]] = None, + seed: Optional[Union[int, Generator]] = None, ): """Construct SQLiteDatasetPerturbed. @@ -80,10 +80,7 @@ def __init__( in this case to events with no value in the corresponding table/column. That is, if no per-event loss weight table/column is provided, this value is ignored. Defaults to None. - seed: Optional seed for random number generation (int or numpy Generator). - If provided, it will be used to initialize the random number generator - for data perturbation. Defaults to None. - + seed: Optional seed for random number generation. Defaults to None. """ # Base class constructor super().__init__( @@ -120,7 +117,9 @@ def __init__( elif isinstance(seed, Generator): self.rng = seed else: - raise ValueError("Invalid seed. Must be an int or a numpy Generator.") + raise ValueError( + "Invalid seed. Must be an int or a numpy Generator." + ) else: self.rng = default_rng() diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index 4518bf4f1..a092c36ed 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -115,7 +115,6 @@ def make_train_validation_dataloader( ) -> Tuple[DataLoader, DataLoader]: """Construct train and test `DataLoader` instances.""" # Reproducibility - rng = np.random.default_rng(seed=seed) # Checks(s) if isinstance(pulsemaps, str): @@ -156,16 +155,16 @@ def make_train_validation_dataloader( {"event_no": selection, "db": database_indices} ) shuffled_df = df_for_shuffle.sample( - frac=1, replace=False, random_state=rng + frac=1, replace=False, random_state=seed ) training_df, validation_df = train_test_split( - shuffled_df, test_size=test_size, random_state=rng + shuffled_df, test_size=test_size, random_state=seed ) training_selection = training_df.values.tolist() validation_selection = validation_df.values.tolist() else: training_selection, validation_selection = train_test_split( - selection, test_size=test_size, random_state=rng + selection, test_size=test_size, random_state=seed ) # Create DataLoaders From 38b4c559d391959a13d465663543c7e1a0948cb0 Mon Sep 17 00:00:00 2001 From: samadpls Date: Sat, 9 Sep 2023 22:55:00 +0500 Subject: [PATCH 4/6] fixed the typo --- src/graphnet/training/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/training/utils.py b/src/graphnet/training/utils.py index a092c36ed..f7d5249f9 100644 --- a/src/graphnet/training/utils.py +++ b/src/graphnet/training/utils.py @@ -115,7 +115,7 @@ def make_train_validation_dataloader( ) -> Tuple[DataLoader, DataLoader]: """Construct train and test `DataLoader` instances.""" # Reproducibility - + rng = np.random.default_rng(seed=seed) # Checks(s) if isinstance(pulsemaps, str): pulsemaps = [pulsemaps] @@ -155,7 +155,7 @@ def make_train_validation_dataloader( {"event_no": selection, "db": database_indices} ) shuffled_df = df_for_shuffle.sample( - frac=1, replace=False, random_state=seed + frac=1, replace=False, random_state=rng ) training_df, validation_df = train_test_split( shuffled_df, test_size=test_size, random_state=seed From 37512c61b88baa48d6fbecbe650f2f350d1a0587 Mon Sep 17 00:00:00 2001 From: samadpls Date: Sun, 10 Sep 2023 00:19:33 +0500 Subject: [PATCH 5/6] added seed in super() --- src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py index b951e6916..cb3e72c81 100644 --- a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py @@ -98,6 +98,7 @@ def __init__( loss_weight_table=loss_weight_table, loss_weight_column=loss_weight_column, loss_weight_default_value=loss_weight_default_value, + seed=seed, ) # Custom member variables From e375e4be02122a00920c40761c3d1b5d5b59841d Mon Sep 17 00:00:00 2001 From: samadpls Date: Sun, 10 Sep 2023 00:25:50 +0500 Subject: [PATCH 6/6] fixed typo --- src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py index cb3e72c81..b951e6916 100644 --- a/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py +++ b/src/graphnet/data/dataset/sqlite/sqlite_dataset_perturbed.py @@ -98,7 +98,6 @@ def __init__( loss_weight_table=loss_weight_table, loss_weight_column=loss_weight_column, loss_weight_default_value=loss_weight_default_value, - seed=seed, ) # Custom member variables