WiP: Datasets reworked #61

Merged · 15 commits · Jun 28, 2024
2 changes: 1 addition & 1 deletion src/downstream_tasks/subclass_identification.py
@@ -7,7 +7,7 @@
from src.explainers.functional import ExplainFunc
from src.explainers.wrappers.captum_influence import captum_similarity_explain
from src.metrics.localization.identical_class import IdenticalClass
from src.utils.datasets.group_label_dataset import (
from src.utils.datasets.transformed_datasets.label_grouping import (
ClassToGroupLiterals,
GroupLabelDataset,
)
8 changes: 4 additions & 4 deletions src/metrics/randomization/model_randomization.py
@@ -27,7 +27,7 @@ def __init__(
seed: int = 42,
model_id: str = "0",
cache_dir: str = "./cache",
device: str = "cpu" if torch.cuda.is_available() else "cuda",
device: str = "cpu",
*args,
**kwargs,
):
@@ -101,15 +101,15 @@ def explain_update(
corrs = self.corr_measure(explanations, rand_explanations)
self.results["scores"].append(corrs)

def compute(self):
def compute(self) -> torch.Tensor:
return torch.cat(self.results["scores"]).mean()

def reset(self):
self.results = {"scores": []}
self.generator.manual_seed(self.seed)
self.rand_model = self._randomize_model(self.model)

def state_dict(self):
def state_dict(self) -> Dict:
state_dict = {
"results_dict": self.results,
"rnd_model": self.model.state_dict(),
@@ -129,7 +129,7 @@ def load_state_dict(self, state_dict: dict):
# self.explain_fn = state_dict["explain_fn"]
# self.generator.set_state(state_dict["generator_state"])

def _randomize_model(self, model: torch.nn.Module):
def _randomize_model(self, model: torch.nn.Module) -> torch.nn.Module:
rand_model = copy.deepcopy(model)
for name, param in list(rand_model.named_parameters()):
random_param_tensor = torch.empty_like(param).normal_(generator=self.generator)
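The `_randomize_model` helper above deep-copies the model and refills every parameter from a seeded normal distribution. A framework-free sketch of the same idea (hypothetical `randomize_params` helper, using plain Python lists in place of tensors and `random.Random` in place of the torch generator):

```python
import copy
import random

def randomize_params(params: dict, seed: int = 42) -> dict:
    # Deep-copy so the original "model" is untouched, then overwrite each
    # parameter with draws from a seeded standard normal, mirroring
    # torch.empty_like(param).normal_(generator=...) in the diff above.
    rng = random.Random(seed)
    rand = copy.deepcopy(params)
    for name, values in rand.items():
        rand[name] = [rng.gauss(0.0, 1.0) for _ in values]
    return rand
```

Because the copy is deep and the generator is seeded, repeated calls with the same seed yield identical randomized parameters while leaving the source model intact.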
76 changes: 0 additions & 76 deletions src/utils/datasets/corrupt_label_dataset.py

This file was deleted.

53 changes: 0 additions & 53 deletions src/utils/datasets/group_label_dataset.py

This file was deleted.

32 changes: 0 additions & 32 deletions src/utils/datasets/label_transform_dataset.py

This file was deleted.

20 changes: 0 additions & 20 deletions src/utils/datasets/reduced_label_dataset.py

This file was deleted.

61 changes: 0 additions & 61 deletions src/utils/datasets/sample_transform_dataset.py

This file was deleted.

68 changes: 68 additions & 0 deletions src/utils/datasets/transformed_datasets/base.py
@@ -0,0 +1,68 @@
import random
from typing import Any, Callable, Optional, Sized

import torch
from torch.utils.data.dataset import Dataset


class TransformedDataset(Dataset):
@staticmethod
def _identity(x: Any) -> Any:
return x

def __init__(
self,
dataset: torch.utils.data.Dataset,
n_classes: int,
cache_path: str = "./cache",
cls_idx: Optional[int] = None,
# If cls_idx is an int, only samples of that class are perturbed, each with probability p;
# if cls_idx is None, every sample is a candidate for perturbation with probability p.
p: float = 1.0,
seed: Optional[int] = None,
device: str = "cpu",
sample_fn: Optional[Callable] = None,
label_fn: Optional[Callable] = None,
):
super().__init__()
self.dataset = dataset
self.n_classes = n_classes
self.cls_idx = cls_idx
self.cache_path = cache_path
self.p = p
if sample_fn is not None:
self.sample_fn = sample_fn
else:
self.sample_fn = TransformedDataset._identity
if label_fn is not None:
self.label_fn = label_fn
else:
self.label_fn = TransformedDataset._identity
self.seed = seed
if self.seed is not None:
random.seed(self.seed)
self.samples_to_perturb = []
for i in range(self.__len__()):
Review comment (Owner): [nit-picking] list comprehensions are generally faster and nicer-looking than loops.

Reply (Collaborator Author): I disagree that it would look better in this particular case; there are too many components to put on one line. However, if it is actually faster, that should probably take priority, because we are looping over a whole train dataset. I cannot commit the single-line version: Black (version 24.0.0) reformats it into a form that flake8 (6.0.0) does not accept. I am commenting it out for your hands to solve 🙏🏼
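A comprehension form of the selection loop under discussion might look like the following standalone sketch (hypothetical `select_perturb_indices` helper operating on a plain label sequence; note that short-circuiting the Bernoulli draw changes how many random numbers are consumed compared with the original loop):

```python
import random
from typing import List, Optional, Sequence

def select_perturb_indices(
    labels: Sequence[int],
    cls_idx: Optional[int] = None,
    p: float = 1.0,
    seed: Optional[int] = None,
) -> List[int]:
    # Keep index i when its label matches cls_idx (or cls_idx is None)
    # and an independent Bernoulli(p) draw succeeds; p >= 1.0 keeps all.
    if seed is not None:
        random.seed(seed)
    return [
        i
        for i, y in enumerate(labels)
        if (cls_idx is None or y == cls_idx)
        and (p >= 1.0 or random.random() <= p)
    ]
```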

x, y = dataset[i]
perturb_sample = (self.cls_idx is None) or (y == self.cls_idx)
p_condition = (random.random() <= self.p) if self.p < 1.0 else True
Review comment (Owner): Do the papers actually require transforming p% of the data, or doing it this way?

Reply (Collaborator Author): In some cases we need to poison, or watermark, a random subset of a class. In the other cases, the child class does not have a p parameter and passes p=1.0 to the base class.

perturb_sample = p_condition and perturb_sample
if perturb_sample:
self.samples_to_perturb.append(i)

def __len__(self):
if isinstance(self.dataset, Sized):
return len(self.dataset)
dl = torch.utils.data.DataLoader(self.dataset, batch_size=1)
return len(dl)

def __getitem__(self, index):
x, y = self.dataset[index]
xx = self.sample_fn(x)
yy = self.label_fn(y)

return (xx, yy) if index in self.samples_to_perturb else (x, y)

def _get_original_label(self, index):
_, y = self.dataset[index]
return y
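The return statement in `__getitem__` needs the parentheses: without them, Python's conditional-expression precedence would parse `xx, yy if cond else x, y` as the 3-tuple `(xx, (yy if cond else x), y)`. A minimal sketch of the intended behavior (hypothetical `transformed_item` helper, framework-free):

```python
def transformed_item(x, y, perturb, sample_fn=lambda v: v, label_fn=lambda v: v):
    # Parentheses matter: "(a, b) if cond else (c, d)" returns one pair or
    # the other, whereas "a, b if cond else c, d" builds a 3-tuple.
    return (sample_fn(x), label_fn(y)) if perturb else (x, y)
```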