
WiP: Datasets reworked #61

Merged (15 commits) on Jun 28, 2024
133 changes: 133 additions & 0 deletions src/downstream_tasks/subclass_identification.py
@@ -0,0 +1,133 @@
import os
from typing import Callable, Dict, Optional, Union

import lightning as L
import torch

from src.explainers.functional import ExplainFunc
from src.explainers.wrappers.captum_influence import captum_similarity_explain
from src.metrics.localization.identical_class import IdenticalClass
from src.utils.datasets.transformed.label_grouping import (
    ClassToGroupLiterals,
    LabelGroupingDataset,
)
from src.utils.training.trainer import BaseTrainer, Trainer


class SubclassIdentification:
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: Callable,
        lr: float,
        criterion: torch.nn.modules.loss._Loss,
        scheduler: Optional[Callable] = None,
        optimizer_kwargs: Optional[dict] = None,
        scheduler_kwargs: Optional[dict] = None,
        device: str = "cpu",
        *args,
        **kwargs,
    ):
        self.device = device
        self.trainer: Optional[BaseTrainer] = Trainer.from_arguments(
            model=model,
            optimizer=optimizer,
            lr=lr,
            scheduler=scheduler,
            criterion=criterion,
            optimizer_kwargs=optimizer_kwargs,
            scheduler_kwargs=scheduler_kwargs,
        )

    @classmethod
    def from_pl_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs):
        obj = cls.__new__(cls)
        super(SubclassIdentification, obj).__init__()
        obj.device = device
        obj.trainer = Trainer.from_lightning_module(model, pl_module)
        return obj

    @classmethod
    def from_trainer(cls, trainer: BaseTrainer, device: str = "cpu", *args, **kwargs):
        obj = cls.__new__(cls)
        super(SubclassIdentification, obj).__init__()
        if isinstance(trainer, BaseTrainer):
            obj.trainer = trainer
            obj.device = device
        else:
            raise ValueError("trainer must be an instance of BaseTrainer")
        return obj

    def evaluate(
        self,
        train_dataset: torch.utils.data.Dataset,
        val_dataset: Optional[torch.utils.data.Dataset] = None,
        n_classes: int = 10,
        n_groups: int = 2,
        class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
        explain_fn: ExplainFunc = captum_similarity_explain,
        explain_kwargs: Optional[dict] = None,
        trainer_kwargs: Optional[dict] = None,
        cache_dir: str = "./cache",
        model_id: str = "default_model_id",
        run_id: str = "default_subclass_identification",
        seed: int = 27,
        batch_size: int = 8,
        device: str = "cpu",
        *args,
        **kwargs,
    ):
        if self.trainer is None:
            raise ValueError(
                "Trainer not initialized. Please initialize the trainer via the constructor, "
                "from_pl_module, or from_trainer."
            )
        if explain_kwargs is None:
            explain_kwargs = {}
        if trainer_kwargs is None:
            trainer_kwargs = {}

        grouped_dataset = LabelGroupingDataset(
            dataset=train_dataset,
            n_classes=n_classes,
            n_groups=n_groups,
            class_to_group=class_to_group,
            seed=seed,
        )
        grouped_train_loader = torch.utils.data.DataLoader(grouped_dataset, batch_size=batch_size)
        original_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
        if val_dataset:
            # reuse the group assignment drawn for the training set
            grouped_val_dataset = LabelGroupingDataset(
                dataset=val_dataset,
                n_classes=n_classes,
                n_groups=n_groups,
                class_to_group=grouped_dataset.class_to_group,
                seed=seed,
            )
            val_loader: Optional[torch.utils.data.DataLoader] = torch.utils.data.DataLoader(
                grouped_val_dataset, batch_size=batch_size
            )
        else:
            val_loader = None

        model = self.trainer.fit(
            train_loader=grouped_train_loader,
            val_loader=val_loader,
            trainer_kwargs=trainer_kwargs,
        )
        metric = IdenticalClass(model=model, train_dataset=train_dataset, device=device)

        for inputs, labels in original_train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            explanations = explain_fn(
                model=model,
                model_id=model_id,
                cache_dir=os.path.join(cache_dir, run_id),
                train_dataset=train_dataset,
                test_tensor=inputs,
                device=device,
                **explain_kwargs,
            )
            metric.update(labels, explanations)

        return metric.compute()
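A minimal usage sketch of the new downstream task (reviewer illustration, not part of this diff; the toy list-based dataset and linear model are assumptions, and the default explain_fn requires captum to be installed):

import torch

from src.downstream_tasks.subclass_identification import SubclassIdentification

# Toy 10-class dataset with int labels; the model predicts one of n_groups=2 group labels.
train_dataset = [(torch.randn(1, 28, 28), i % 10) for i in range(64)]
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 2))

task = SubclassIdentification(
    model=model,
    optimizer=torch.optim.Adam,
    lr=1e-3,
    criterion=torch.nn.CrossEntropyLoss(),
)
score = task.evaluate(train_dataset=train_dataset, n_classes=10, n_groups=2, batch_size=8)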
8 changes: 4 additions & 4 deletions src/metrics/randomization/model_randomization.py
@@ -27,7 +27,7 @@ def __init__(
        seed: int = 42,
        model_id: str = "0",
        cache_dir: str = "./cache",
-        device: str = "cpu" if torch.cuda.is_available() else "cuda",
+        device: str = "cpu",
        *args,
        **kwargs,
    ):
@@ -104,15 +104,15 @@ def explain_update(
        corrs = self.corr_measure(explanations, rand_explanations)
        self.results["scores"].append(corrs)

-    def compute(self):
+    def compute(self) -> torch.Tensor:
        return torch.cat(self.results["scores"]).mean()

    def reset(self):
        self.results = {"scores": []}
        self.generator.manual_seed(self.seed)
        self.rand_model = self._randomize_model(self.model)

-    def state_dict(self):
+    def state_dict(self) -> Dict:
        state_dict = {
            "results_dict": self.results,
            "rnd_model": self.model.state_dict(),
@@ -132,7 +132,7 @@ def load_state_dict(self, state_dict: dict):
        # self.explain_fn = state_dict["explain_fn"]
        # self.generator.set_state(state_dict["generator_state"])

-    def _randomize_model(self, model: torch.nn.Module):
+    def _randomize_model(self, model: torch.nn.Module) -> torch.nn.Module:
        rand_model = copy.deepcopy(model)
        for name, param in list(rand_model.named_parameters()):
            random_param_tensor = torch.empty_like(param).normal_(generator=self.generator)
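The four edits above only pin the default device and add return annotations; behaviour is unchanged. For reviewers, the randomization idea behind _randomize_model can be sketched standalone (a simplified illustration, not the module's actual code):

import copy

import torch


def randomize_model(model: torch.nn.Module, seed: int = 42) -> torch.nn.Module:
    """Return a deep copy of `model` with every parameter re-drawn from N(0, 1)."""
    generator = torch.Generator().manual_seed(seed)
    rand_model = copy.deepcopy(model)
    with torch.no_grad():
        for param in rand_model.parameters():
            param.copy_(torch.empty_like(param).normal_(generator=generator))
    return rand_model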
68 changes: 68 additions & 0 deletions src/utils/datasets/transformed/base.py
@@ -0,0 +1,68 @@
import random
from typing import Any, Callable, Optional, Sized

import torch
from torch.utils.data.dataset import Dataset


class TransformedDataset(Dataset):
    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        n_classes: int,
        cache_path: str = "./cache",
        cls_idx: Optional[int] = None,
        # If cls_idx is given, only datapoints of that class are perturbed,
        # each with probability p; otherwise all datapoints are candidates.
        p: float = 1.0,
        seed: int = 42,
        device: str = "cpu",
        sample_fn: Optional[Callable] = None,
        label_fn: Optional[Callable] = None,
    ):
        super().__init__()
        self.dataset = dataset
        self.n_classes = n_classes
        self.cls_idx = cls_idx
        self.cache_path = cache_path
        self.p = p
        if sample_fn is not None:
            self.sample_fn = sample_fn
        else:
            self.sample_fn = self._identity
        if label_fn is not None:
            self.label_fn = label_fn
        else:
            self.label_fn = self._identity

        self.seed = seed
        self.rng = random.Random(seed)
        self.torch_rng = torch.Generator()
        self.torch_rng.manual_seed(seed)

        perturb_mask = torch.rand(len(self), generator=self.torch_rng) <= self.p
        if self.cls_idx is not None:
            perturb_mask &= torch.tensor(
                [self.dataset[s][1] == self.cls_idx for s in range(len(self))], dtype=torch.bool
            )
        # Store the perturbed indices as a set so that `index in self.samples_to_perturb`
        # is a genuine membership test; an `in` check against the boolean mask itself
        # would compare tensor values, not indices.
        self.samples_to_perturb = set(torch.where(perturb_mask)[0].tolist())

    def __len__(self) -> int:
        if isinstance(self.dataset, Sized):
            return len(self.dataset)
        dl = torch.utils.data.DataLoader(self.dataset, batch_size=1)
        return len(dl)

    def __getitem__(self, index) -> Any:
        x, y = self.dataset[index]
        xx = self.sample_fn(x)
        yy = self.label_fn(y)

        return (xx, yy) if index in self.samples_to_perturb else (x, y)

    def _get_original_label(self, index) -> int:
        _, y = self.dataset[index]
        return y

    @staticmethod
    def _identity(x: Any) -> Any:
        return x
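A short usage sketch for the base class (illustration only; the additive-noise sample_fn and the plain list standing in for a map-style dataset are assumptions):

import torch

from src.utils.datasets.transformed.base import TransformedDataset

# A plain list of (sample, int label) pairs acts as a map-style dataset.
base = [(torch.randn(3), i % 10) for i in range(100)]
noisy = TransformedDataset(
    dataset=base,
    n_classes=10,
    p=0.3,  # roughly 30% of samples are perturbed
    seed=42,
    sample_fn=lambda x: x + 0.1 * torch.randn_like(x),
)
x, y = noisy[0]  # perturbed iff index 0 was drawn into samples_to_perturb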
50 changes: 50 additions & 0 deletions src/utils/datasets/transformed/label_grouping.py
@@ -0,0 +1,50 @@
from typing import Dict, Literal, Union

import torch

from src.utils.datasets.transformed.base import TransformedDataset

ClassToGroupLiterals = Literal["random"]


class LabelGroupingDataset(TransformedDataset):
    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        n_classes: int,
        seed: int = 42,
        device: str = "cpu",
        n_groups: int = 2,
        class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
    ):
        super().__init__(
            dataset=dataset,
            n_classes=n_classes,
            seed=seed,
            device=device,
            p=1.0,  # apply to all datapoints with certainty
            cls_idx=None,
        )
        self.n_classes = n_classes
        self.classes = list(range(n_classes))
        self.n_groups = n_groups
        self.groups = list(range(n_groups))

        if class_to_group == "random":
            # assign each class to a randomly drawn group
            self.class_to_group = {i: self.rng.randint(0, n_groups - 1) for i in range(n_classes)}
        elif isinstance(class_to_group, dict):
            self._validate_class_to_group(class_to_group)
            self.class_to_group = class_to_group
        else:
            raise ValueError(f"Invalid class_to_group value: {class_to_group}")
        self.label_fn = lambda x: self.class_to_group[x]

    def _validate_class_to_group(self, class_to_group):
        # validate the passed mapping; self.class_to_group is not set yet at this point
        assert len(class_to_group) == self.n_classes
        assert all(g in self.groups for g in class_to_group.values())
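With the "random" option every sample is relabeled through one fixed class-to-group table (sketch, under the same toy-dataset assumption as above):

import torch

from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset

base = [(torch.randn(3), i % 10) for i in range(100)]
grouped = LabelGroupingDataset(dataset=base, n_classes=10, n_groups=2, seed=42)
print(grouped.class_to_group)  # e.g. {0: 1, 1: 0, ..., 9: 1}
_, group_label = grouped[0]  # a group id in range(n_groups), not the original class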
31 changes: 31 additions & 0 deletions src/utils/datasets/transformed/label_poisoning.py
@@ -0,0 +1,31 @@
from typing import Optional

import torch

from src.utils.datasets.transformed.base import TransformedDataset


class LabelPoisoningDataset(TransformedDataset):
    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        n_classes: int,
        cls_idx: Optional[int] = None,
        p: float = 1.0,  # TODO: decide on default value vis-à-vis cls_idx
        seed: int = 42,
        device: str = "cpu",
    ):
        super().__init__(dataset=dataset, n_classes=n_classes, seed=seed, device=device, p=p, cls_idx=cls_idx)
        self.poisoned_labels = {i: self._poison(self.dataset[i][1]) for i in range(len(self))}

    def _poison(self, original_label):
        # draw a replacement label uniformly from all classes except the original one
        label_arr = [i for i in range(self.n_classes) if original_label != i]
        label_idx = self.rng.randint(0, len(label_arr) - 1)  # random.randint is inclusive on both ends
        return label_arr[label_idx]

    def __getitem__(self, index):
        x, y = self.dataset[index]
        if index in self.samples_to_perturb:
            y = self.poisoned_labels[index]
        return x, y
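With p=1.0 every label is flipped to a uniformly drawn different class (sketch, same toy-dataset assumption):

import torch

from src.utils.datasets.transformed.label_poisoning import LabelPoisoningDataset

base = [(torch.randn(3), i % 10) for i in range(100)]
poisoned = LabelPoisoningDataset(dataset=base, n_classes=10, p=1.0, seed=42)
for i in range(len(base)):
    assert poisoned[i][1] != base[i][1]  # the poisoned label always differs from the original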
30 changes: 30 additions & 0 deletions src/utils/datasets/transformed/sample.py
@@ -0,0 +1,30 @@
from typing import Callable, Optional

import torch

from src.utils.datasets.transformed.base import TransformedDataset


class SampleTransformationDataset(TransformedDataset):
    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        n_classes: int,
        cls_idx: Optional[int] = None,
        p: float = 1.0,
        seed: int = 42,
        device: str = "cpu",
        sample_fn: Optional[Callable] = None,
    ):
        super().__init__(
            dataset=dataset,
            n_classes=n_classes,
            seed=seed,
            device=device,
            p=p,
            cls_idx=cls_idx,
            sample_fn=sample_fn,
        )
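The subclass only forwards sample_fn to the base class; a sketch with an assumed blank-out corruption:

import torch

from src.utils.datasets.transformed.sample import SampleTransformationDataset

base = [(torch.randn(3), i % 10) for i in range(100)]
corrupted = SampleTransformationDataset(
    dataset=base,
    n_classes=10,
    p=0.5,  # blank out roughly half of the samples
    seed=42,
    sample_fn=lambda x: torch.zeros_like(x),
)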
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -5,7 +5,7 @@
import torch
from torch.utils.data import TensorDataset

-from src.utils.datasets.group_label_dataset import GroupLabelDataset
+from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset
from tests.models import LeNet

MNIST_IMAGE_SIZE = 28
@@ -84,7 +84,7 @@ def load_grouped_mnist_dataset():
    )[:MINI_BATCH_SIZE]
    y_batch = np.loadtxt("tests/assets/mnist_test_suite_1/mnist_y").astype(int)[:MINI_BATCH_SIZE]
    dataset = TestTensorDataset(torch.tensor(x_batch).float(), torch.tensor(y_batch).long())
-    return GroupLabelDataset(
+    return LabelGroupingDataset(
        dataset,
        n_classes=10,
        n_groups=2,
2 changes: 1 addition & 1 deletion tests/metrics/test_randomization_metrics.py
@@ -66,7 +66,7 @@ def test_randomization_metric(
def test_randomization_metric_model_randomization(test_id, model, dataset, request):
    model = request.getfixturevalue(model)
    dataset = request.getfixturevalue(dataset)
-    metric = ModelRandomizationMetric(model=model, train_dataset=dataset, explain_fn=lambda x: x, seed=42, device="cpu")
+    metric = ModelRandomizationMetric(model=model, train_dataset=dataset, explain_fn=lambda *x: x, seed=42, device="cpu")
    rand_model = metric.rand_model
    for param1, param2 in zip(model.parameters(), rand_model.parameters()):
        assert not torch.allclose(param1.data, param2.data), "Test failed."