From 4ec1522a8471a39a208c5978323443d37d1327d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Fri, 28 Jun 2024 17:18:39 +0200 Subject: [PATCH] switch to using self.rng and non-optional seed --- .../datasets/transformed_datasets/base.py | 26 +++++++++---------- .../transformed_datasets/label_grouping.py | 7 +++-- .../transformed_datasets/label_poisoning.py | 5 ++-- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/utils/datasets/transformed_datasets/base.py b/src/utils/datasets/transformed_datasets/base.py index 9c53c96a..dea81e3e 100644 --- a/src/utils/datasets/transformed_datasets/base.py +++ b/src/utils/datasets/transformed_datasets/base.py @@ -6,10 +6,6 @@ class TransformedDataset(Dataset): - @staticmethod - def _identity(x: Any) -> Any: - return x - def __init__( self, dataset: torch.utils.data.Dataset, @@ -19,7 +15,7 @@ def __init__( # If isinstance(subset_idx,int): perturb this class with probability p, # if isinstance(subset_idx,List[int]): perturb datapoints with these indices with probability p p: float = 1.0, - seed: Optional[int] = None, + seed: int = 42, device: str = "cpu", sample_fn: Optional[Callable] = None, label_fn: Optional[Callable] = None, @@ -33,36 +29,40 @@ def __init__( if sample_fn is not None: self.sample_fn = sample_fn else: - self.sample_fn = TransformedDataset._identity + self.sample_fn = self._identity if label_fn is not None: self.label_fn = label_fn else: - self.label_fn = TransformedDataset._identity + self.label_fn = self._identity self.seed = seed - if self.seed is not None: - random.seed(self.seed) + self.rng = random.Random() + self.rng.seed(self.seed) self.samples_to_perturb = [] for i in range(self.__len__()): x, y = dataset[i] perturb_sample = (self.cls_idx is None) or (y == self.cls_idx) - p_condition = (random.random() <= self.p) if self.p < 1.0 else True + p_condition = (self.rng.random() <= self.p) if self.p < 1.0 else True perturb_sample = p_condition and perturb_sample if 
perturb_sample: self.samples_to_perturb.append(i) - def __len__(self): + def __len__(self) -> int: if isinstance(self.dataset, Sized): return len(self.dataset) dl = torch.utils.data.DataLoader(self.dataset, batch_size=1) return len(dl) - def __getitem__(self, index): + def __getitem__(self, index) -> Any: x, y = self.dataset[index] xx = self.sample_fn(x) yy = self.label_fn(y) - return xx, yy if index in self.samples_to_perturb else x, y + return (xx, yy) if index in self.samples_to_perturb else (x, y) - def _get_original_label(self, index): + def _get_original_label(self, index) -> int: _, y = self.dataset[index] return y + + @staticmethod + def _identity(x: Any) -> Any: + return x diff --git a/src/utils/datasets/transformed_datasets/label_grouping.py b/src/utils/datasets/transformed_datasets/label_grouping.py index d02b9f5c..79a3a8b5 100644 --- a/src/utils/datasets/transformed_datasets/label_grouping.py +++ b/src/utils/datasets/transformed_datasets/label_grouping.py @@ -1,5 +1,4 @@ -import random -from typing import Dict, Literal, Optional, Union, cast +from typing import Dict, Literal, Union import torch @@ -13,7 +12,7 @@ def __init__( self, dataset: torch.utils.data.Dataset, n_classes: int, - seed: Optional[int] = None, + seed: int = 42, device: str = "cpu", n_groups: int = 2, class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random", @@ -33,7 +32,7 @@ def __init__( self.groups = list(range(n_groups)) if class_to_group == "random": # create a dictionary of class groups that assigns each class to a group - group_assignments = [random.randint(0, n_groups - 1) for _ in range(n_classes)] + group_assignments = [self.rng.randint(0, n_groups - 1) for _ in range(n_classes)] self.class_to_group = {} for i in range(n_classes): self.class_to_group[i] = group_assignments[i] diff --git a/src/utils/datasets/transformed_datasets/label_poisoning.py b/src/utils/datasets/transformed_datasets/label_poisoning.py index 29e57a99..220eee3e 100644 --- a/src/utils/datasets/transformed_datasets/label_poisoning.py +++
b/src/utils/datasets/transformed_datasets/label_poisoning.py @@ -1,4 +1,3 @@ -import random from typing import Optional import torch @@ -13,7 +12,7 @@ def __init__( n_classes: int, cls_idx: Optional[int] = None, p: float = 1.0, # TODO: decide on default value vis-à-vis subset_idx - seed: Optional[int] = None, + seed: int = 42, device: str = "cpu", ): @@ -25,7 +24,7 @@ def __init__( def _poison(self, original_label): label_arr = [i for i in range(self.n_classes) if original_label != i] - label_idx = random.randint(0, len(label_arr)) + label_idx = self.rng.randint(0, len(label_arr) - 1) return label_arr[label_idx] def __getitem__(self, index):