Skip to content

Commit

Permalink
switch to using self.rng and non-optional seed
Browse files Browse the repository at this point in the history
  • Loading branch information
gumityolcu committed Jun 28, 2024
1 parent a2ad1e9 commit 4ec1522
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 20 deletions.
26 changes: 13 additions & 13 deletions src/utils/datasets/transformed_datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@


class TransformedDataset(Dataset):
@staticmethod
def _identity(x: Any) -> Any:
return x

def __init__(
self,
dataset: torch.utils.data.Dataset,
Expand All @@ -19,7 +15,7 @@ def __init__(
# If isinstance(subset_idx,int): perturb this class with probability p,
# if isinstance(subset_idx,List[int]): perturb datapoints with these indices with probability p
p: float = 1.0,
seed: Optional[int] = None,
seed: int = 42,
device: str = "cpu",
sample_fn: Optional[Callable] = None,
label_fn: Optional[Callable] = None,
Expand All @@ -33,36 +29,40 @@ def __init__(
if sample_fn is not None:
self.sample_fn = sample_fn
else:
self.sample_fn = TransformedDataset._identity
self.sample_fn = self._identity
if label_fn is not None:
self.label_fn = label_fn
else:
self.label_fn = TransformedDataset._identity
self.label_fn = self._identity
self.seed = seed
if self.seed is not None:
random.seed(self.seed)
self.rng = random.Random()
self.rng.seed(self.seed)
self.samples_to_perturb = []
for i in range(self.__len__()):
x, y = dataset[i]
perturb_sample = (self.cls_idx is None) or (y == self.cls_idx)
p_condition = (random.random() <= self.p) if self.p < 1.0 else True
p_condition = (self.rng.random() <= self.p) if self.p < 1.0 else True
perturb_sample = p_condition and perturb_sample
if perturb_sample:
self.samples_to_perturb.append(i)

def __len__(self):
def __len__(self) -> int:
if isinstance(self.dataset, Sized):
return len(self.dataset)
dl = torch.utils.data.DataLoader(self.dataset, batch_size=1)
return len(dl)

def __getitem__(self, index):
def __getitem__(self, index) -> Any:
x, y = self.dataset[index]
xx = self.sample_fn(x)
yy = self.label_fn(y)

return xx, yy if index in self.samples_to_perturb else x, y

def _get_original_label(self, index):
def _get_original_label(self, index) -> int:
_, y = self.dataset[index]
return y

@staticmethod
def _identity(x: Any) -> Any:
return x
7 changes: 3 additions & 4 deletions src/utils/datasets/transformed_datasets/label_grouping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import random
from typing import Dict, Literal, Optional, Union, cast
from typing import Dict, Literal, Union

import torch

Expand All @@ -13,7 +12,7 @@ def __init__(
self,
dataset: torch.utils.data.Dataset,
n_classes: int,
seed: Optional[int] = None,
seed: int = 42,
device: str = "cpu",
n_groups: int = 2,
class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
Expand All @@ -33,7 +32,7 @@ def __init__(
self.groups = list(range(n_groups))
if class_to_group == "random":
# create a dictionary of class groups that assigns each class to a group
group_assignments = [random.randint(0, n_groups - 1) for _ in range(n_classes)]
group_assignments = [self.rng.randint(0, n_groups - 1) for _ in range(n_classes)]
self.class_to_group = {}
for i in range(n_classes):
self.class_to_group[i] = group_assignments[i]
Expand Down
5 changes: 2 additions & 3 deletions src/utils/datasets/transformed_datasets/label_poisoning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from typing import Optional

import torch
Expand All @@ -13,7 +12,7 @@ def __init__(
n_classes: int,
cls_idx: Optional[int] = None,
p: float = 1.0, # TODO: decide on default value vis-à-vis subset_idx
seed: Optional[int] = None,
seed: int = 42,
device: str = "cpu",
):

Expand All @@ -25,7 +24,7 @@ def __init__(

def _poison(self, original_label):
label_arr = [i for i in range(self.n_classes) if original_label != i]
label_idx = random.randint(0, len(label_arr))
label_idx = self.rng.randint(0, len(label_arr))
return label_arr[label_idx]

def __getitem__(self, index):
Expand Down

0 comments on commit 4ec1522

Please sign in to comment.