diff --git a/src/downstream_tasks/subclass_identification.py b/src/downstream_tasks/subclass_identification.py
new file mode 100644
index 00000000..6e40c775
--- /dev/null
+++ b/src/downstream_tasks/subclass_identification.py
@@ -0,0 +1,133 @@
+import os
+from typing import Callable, Dict, Optional, Union
+
+import lightning as L
+import torch
+
+from src.explainers.functional import ExplainFunc
+from src.explainers.wrappers.captum_influence import captum_similarity_explain
+from src.metrics.localization.identical_class import IdenticalClass
+from src.utils.datasets.transformed.label_grouping import (
+    ClassToGroupLiterals,
+    LabelGroupingDataset,
+)
+from src.utils.training.trainer import BaseTrainer, Trainer
+
+
+class SubclassIdentification:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: Callable,
+        lr: float,
+        criterion: torch.nn.modules.loss._Loss,
+        scheduler: Optional[Callable] = None,
+        optimizer_kwargs: Optional[dict] = None,
+        scheduler_kwargs: Optional[dict] = None,
+        device: str = "cpu",
+        *args,
+        **kwargs,
+    ):
+        self.device = device
+        self.trainer: Optional[BaseTrainer] = Trainer.from_arguments(
+            model=model,
+            optimizer=optimizer,
+            lr=lr,
+            scheduler=scheduler,
+            criterion=criterion,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler_kwargs=scheduler_kwargs,
+        )
+
+    @classmethod
+    def from_pl_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs):
+        obj = cls.__new__(cls)
+        super(SubclassIdentification, obj).__init__()
+        obj.device = device
+        obj.trainer = Trainer.from_lightning_module(model, pl_module)
+        return obj
+
+    @classmethod
+    def from_trainer(cls, trainer: BaseTrainer, device: str = "cpu", *args, **kwargs):
+        obj = cls.__new__(cls)
+        super(SubclassIdentification, obj).__init__()
+        if isinstance(trainer, BaseTrainer):
+            obj.trainer = trainer
+            obj.device = device
+        else:
+            raise ValueError("trainer must be an instance of BaseTrainer")
+        return obj
+
+    def evaluate(
+        self,
+        train_dataset: torch.utils.data.Dataset,
+        val_dataset: Optional[torch.utils.data.Dataset] = None,
+        n_classes: int = 10,
+        n_groups: int = 2,
+        class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
+        explain_fn: ExplainFunc = captum_similarity_explain,
+        explain_kwargs: Optional[dict] = None,
+        trainer_kwargs: Optional[dict] = None,
+        cache_dir: str = "./cache",
+        model_id: str = "default_model_id",
+        run_id: str = "default_subclass_identification",
+        seed: int = 27,
+        batch_size: int = 8,
+        device: str = "cpu",
+        *args,
+        **kwargs,
+    ):
+        if self.trainer is None:
+            raise ValueError(
+                "Trainer not initialized. Please initialize the trainer via the constructor, "
+                "from_pl_module or from_trainer"
+            )
+        if explain_kwargs is None:
+            explain_kwargs = {}
+        if trainer_kwargs is None:
+            trainer_kwargs = {}
+
+        grouped_dataset = LabelGroupingDataset(
+            dataset=train_dataset,
+            n_classes=n_classes,
+            n_groups=n_groups,
+            class_to_group=class_to_group,
+            seed=seed,
+        )
+        grouped_train_loader = torch.utils.data.DataLoader(grouped_dataset, batch_size=batch_size)
+        original_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
+        if val_dataset:
+            grouped_val_dataset = LabelGroupingDataset(
+                dataset=val_dataset,
+                n_classes=n_classes,
+                n_groups=n_groups,
+                class_to_group=grouped_dataset.class_to_group,
+                seed=seed,
+            )
+            val_loader: Optional[torch.utils.data.DataLoader] = torch.utils.data.DataLoader(
+                grouped_val_dataset, batch_size=batch_size
+            )
+        else:
+            val_loader = None
+
+        model = self.trainer.fit(
+            train_loader=grouped_train_loader,
+            val_loader=val_loader,
+            trainer_kwargs=trainer_kwargs,
+        )
+        metric = IdenticalClass(model=model, train_dataset=train_dataset, device=device)
+
+        for inputs, labels in original_train_loader:
+            inputs, labels = inputs.to(device), labels.to(device)
+            explanations = explain_fn(
+                model=model,
+                model_id=model_id,
+                cache_dir=os.path.join(cache_dir, run_id),
+                train_dataset=train_dataset,
+                test_tensor=inputs,
+                device=device,
+                **explain_kwargs,
+            )
+            metric.update(labels, explanations)
+
+        return metric.compute()
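Note: end to end, the new downstream task is meant to be driven roughly as follows. This is a minimal sketch, not part of the patch: `my_model` and `my_train_set` are hypothetical placeholders, and the constructor arguments simply mirror the `__init__` signature above.

    import torch

    from src.downstream_tasks.subclass_identification import SubclassIdentification

    task = SubclassIdentification(
        model=my_model,  # any torch.nn.Module classifier (hypothetical)
        optimizer=torch.optim.Adam,
        lr=1e-3,
        criterion=torch.nn.CrossEntropyLoss(),
        device="cpu",
    )
    # trains on labels collapsed into n_groups=2 groups, then checks whether the
    # samples attributed by the explainer share the test sample's original class
    score = task.evaluate(train_dataset=my_train_set, n_classes=10, n_groups=2, batch_size=8)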
diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py
index b098dc23..3caa52a2 100644
--- a/src/metrics/randomization/model_randomization.py
+++ b/src/metrics/randomization/model_randomization.py
@@ -27,7 +27,7 @@ def __init__(
         seed: int = 42,
         model_id: str = "0",
         cache_dir: str = "./cache",
-        device: str = "cpu" if torch.cuda.is_available() else "cuda",
+        device: str = "cpu",
         *args,
         **kwargs,
     ):
@@ -104,7 +104,7 @@ def explain_update(
         corrs = self.corr_measure(explanations, rand_explanations)
         self.results["scores"].append(corrs)
 
-    def compute(self):
+    def compute(self) -> torch.Tensor:
         return torch.cat(self.results["scores"]).mean()
 
     def reset(self):
@@ -112,7 +112,7 @@ def reset(self):
         self.generator.manual_seed(self.seed)
         self.rand_model = self._randomize_model(self.model)
 
-    def state_dict(self):
+    def state_dict(self) -> Dict:
         state_dict = {
             "results_dict": self.results,
             "rnd_model": self.model.state_dict(),
@@ -132,7 +132,7 @@ def load_state_dict(self, state_dict: dict):
         # self.explain_fn = state_dict["explain_fn"]
         # self.generator.set_state(state_dict["generator_state"])
 
-    def _randomize_model(self, model: torch.nn.Module):
+    def _randomize_model(self, model: torch.nn.Module) -> torch.nn.Module:
         rand_model = copy.deepcopy(model)
         for name, param in list(rand_model.named_parameters()):
             random_param_tensor = torch.empty_like(param).normal_(generator=self.generator)
diff --git a/src/utils/datasets/transformed/__init__.py b/src/utils/datasets/transformed/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/utils/datasets/transformed/base.py b/src/utils/datasets/transformed/base.py
new file mode 100644
index 00000000..01a64dc9
--- /dev/null
+++ b/src/utils/datasets/transformed/base.py
@@ -0,0 +1,64 @@
+import random
+from typing import Any, Callable, Optional, Sized
+
+import torch
+from torch.utils.data.dataset import Dataset
+
+
+class TransformedDataset(Dataset):
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        n_classes: int,
+        cache_path: str = "./cache",
+        cls_idx: Optional[int] = None,
+        # If cls_idx is set, only datapoints of that class are candidates for perturbation;
+        # each candidate datapoint is perturbed with probability p.
+        p: float = 1.0,
+        seed: int = 42,
+        device: str = "cpu",
+        sample_fn: Optional[Callable] = None,
+        label_fn: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.dataset = dataset
+        self.n_classes = n_classes
+        self.cls_idx = cls_idx
+        self.cache_path = cache_path
+        self.p = p
+        self.sample_fn = sample_fn if sample_fn is not None else self._identity
+        self.label_fn = label_fn if label_fn is not None else self._identity
+
+        self.seed = seed
+        self.rng = random.Random(seed)
+        self.torch_rng = torch.Generator()
+        self.torch_rng.manual_seed(seed)
+
+        perturb_mask = torch.rand(len(self), generator=self.torch_rng) <= self.p
+        if self.cls_idx is not None:
+            perturb_mask &= torch.tensor(
+                [self.dataset[s][1] == self.cls_idx for s in range(len(self))], dtype=torch.bool
+            )
+        # store plain indices: `index in mask` on a boolean tensor would test against the
+        # values 0/1 rather than check membership of the index
+        self.samples_to_perturb = set(torch.where(perturb_mask)[0].tolist())
+
+    def __len__(self) -> int:
+        if isinstance(self.dataset, Sized):
+            return len(self.dataset)
+        dl = torch.utils.data.DataLoader(self.dataset, batch_size=1)
+        return len(dl)
+
+    def __getitem__(self, index) -> Any:
+        x, y = self.dataset[index]
+        if index in self.samples_to_perturb:
+            return self.sample_fn(x), self.label_fn(y)
+        return x, y
+
+    def _get_original_label(self, index) -> int:
+        _, y = self.dataset[index]
+        return y
+
+    @staticmethod
+    def _identity(x: Any) -> Any:
+        return x
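Note: `TransformedDataset` fixes the set of perturbed indices up front with one Bernoulli(p) draw per sample, optionally restricted to a single class via `cls_idx`. A minimal sketch of using the base class directly (the toy dataset and noise function are illustrative only):

    import torch

    from src.utils.datasets.transformed.base import TransformedDataset

    # toy dataset: (sample, int label) pairs, 10 classes
    data = [(torch.randn(1, 28, 28), i % 10) for i in range(100)]

    # add Gaussian noise to roughly 30% of the class-3 samples
    noisy = TransformedDataset(
        dataset=data,
        n_classes=10,
        cls_idx=3,
        p=0.3,
        seed=42,
        sample_fn=lambda x: x + 0.1 * torch.randn_like(x),
    )
    x, y = noisy[0]  # perturbed iff 0 in noisy.samples_to_perturb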
diff --git a/src/utils/datasets/transformed/label_grouping.py b/src/utils/datasets/transformed/label_grouping.py
new file mode 100644
index 00000000..15500c20
--- /dev/null
+++ b/src/utils/datasets/transformed/label_grouping.py
@@ -0,0 +1,46 @@
+from typing import Dict, Literal, Union
+
+import torch
+
+from src.utils.datasets.transformed.base import TransformedDataset
+
+ClassToGroupLiterals = Literal["random"]
+
+
+class LabelGroupingDataset(TransformedDataset):
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        n_classes: int,
+        seed: int = 42,
+        device: str = "cpu",
+        n_groups: int = 2,
+        class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
+    ):
+
+        super().__init__(
+            dataset=dataset,
+            n_classes=n_classes,
+            seed=seed,
+            device=device,
+            p=1.0,
+            cls_idx=None,  # apply the grouping to all datapoints with certainty
+        )
+        self.classes = list(range(n_classes))
+        self.n_groups = n_groups
+        self.groups = list(range(n_groups))
+
+        if class_to_group == "random":
+            # randomly assign each class to one of the n_groups groups
+            self.class_to_group = {i: self.rng.randint(0, n_groups - 1) for i in range(n_classes)}
+        elif isinstance(class_to_group, dict):
+            self._validate_class_to_group(class_to_group)
+            self.class_to_group = class_to_group
+        else:
+            raise ValueError(f"Invalid class_to_group value: {class_to_group}")
+        self.label_fn = lambda x: self.class_to_group[x]
+
+    def _validate_class_to_group(self, class_to_group):
+        # validate the argument itself: self.class_to_group is not yet assigned at call time
+        assert len(class_to_group) == self.n_classes
+        assert all(g in self.groups for g in class_to_group.values())
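Note: a quick sketch of the grouping transform on toy data (illustrative, not in the patch). With `p=1.0` every label is remapped, so `samples_to_perturb` covers the whole dataset:

    import torch

    from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset

    data = [(torch.randn(8), i % 10) for i in range(100)]  # toy (sample, label) pairs

    # collapse 10 classes into 2 randomly assigned groups
    grouped = LabelGroupingDataset(dataset=data, n_classes=10, n_groups=2, seed=42)
    _, g = grouped[0]
    assert g == grouped.class_to_group[grouped._get_original_label(0)]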
diff --git a/src/utils/datasets/transformed/label_poisoning.py b/src/utils/datasets/transformed/label_poisoning.py
new file mode 100644
index 00000000..3e0f5bfc
--- /dev/null
+++ b/src/utils/datasets/transformed/label_poisoning.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+import torch
+
+from src.utils.datasets.transformed.base import TransformedDataset
+
+
+class LabelPoisoningDataset(TransformedDataset):
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        n_classes: int,
+        cls_idx: Optional[int] = None,
+        p: float = 1.0,  # TODO: decide on default value vis-à-vis cls_idx
+        seed: int = 42,
+        device: str = "cpu",
+    ):
+
+        super().__init__(dataset=dataset, n_classes=n_classes, seed=seed, device=device, p=p, cls_idx=cls_idx)
+        # only datapoints marked for perturbation need a poisoned label; sort for reproducibility
+        self.poisoned_labels = {i: self._poison(self.dataset[i][1]) for i in sorted(self.samples_to_perturb)}
+
+    def _poison(self, original_label):
+        # draw a wrong label uniformly from the remaining classes
+        label_arr = [i for i in range(self.n_classes) if original_label != i]
+        label_idx = self.rng.randint(0, len(label_arr) - 1)  # randint is inclusive on both ends
+        return label_arr[label_idx]
+
+    def __getitem__(self, index):
+        x, y = self.dataset[index]
+        if index in self.samples_to_perturb:
+            y = self.poisoned_labels[index]
+        return x, y
diff --git a/src/utils/datasets/transformed/sample.py b/src/utils/datasets/transformed/sample.py
new file mode 100644
index 00000000..d52c3720
--- /dev/null
+++ b/src/utils/datasets/transformed/sample.py
@@ -0,0 +1,28 @@
+from typing import Callable, Optional
+
+import torch
+
+from src.utils.datasets.transformed.base import TransformedDataset
+
+
+class SampleTransformationDataset(TransformedDataset):
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        n_classes: int,
+        cls_idx: Optional[int] = None,
+        p: float = 1.0,
+        seed: int = 42,
+        device: str = "cpu",
+        sample_fn: Optional[Callable] = None,
+    ):
+
+        super().__init__(
+            dataset=dataset,
+            n_classes=n_classes,
+            seed=seed,
+            device=device,
+            p=p,
+            cls_idx=cls_idx,
+            sample_fn=sample_fn,
+        )
diff --git a/tests/conftest.py b/tests/conftest.py
index 1693ea73..6d6286af 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,7 +5,7 @@
 import torch
 from torch.utils.data import TensorDataset
 
-from src.utils.datasets.group_label_dataset import GroupLabelDataset
+from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset
 from tests.models import LeNet
 
 MNIST_IMAGE_SIZE = 28
@@ -84,7 +84,7 @@ def load_grouped_mnist_dataset():
     )[:MINI_BATCH_SIZE]
     y_batch = np.loadtxt("tests/assets/mnist_test_suite_1/mnist_y").astype(int)[:MINI_BATCH_SIZE]
     dataset = TestTensorDataset(torch.tensor(x_batch).float(), torch.tensor(y_batch).long())
-    return GroupLabelDataset(
+    return LabelGroupingDataset(
         dataset,
         n_classes=10,
         n_groups=2,
diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py
index 0457b421..9557825c 100644
--- a/tests/metrics/test_randomization_metrics.py
+++ b/tests/metrics/test_randomization_metrics.py
@@ -66,7 +66,7 @@ def test_randomization_metric(
 def test_randomization_metric_model_randomization(test_id, model, dataset, request):
     model = request.getfixturevalue(model)
     dataset = request.getfixturevalue(dataset)
-    metric = ModelRandomizationMetric(model=model, train_dataset=dataset, explain_fn=lambda x: x, seed=42, device="cpu")
+    metric = ModelRandomizationMetric(model=model, train_dataset=dataset, explain_fn=lambda *x: x, seed=42, device="cpu")
     rand_model = metric.rand_model
     for param1, param2 in zip(model.parameters(), rand_model.parameters()):
         assert not torch.allclose(param1.data, param2.data), "Test failed."
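Note: the poisoning transform above, sketched under the same toy setup (`data` as in the previous sketch):

    from src.utils.datasets.transformed.label_poisoning import LabelPoisoningDataset

    # flip roughly 20% of all labels to a uniformly drawn wrong class
    poisoned = LabelPoisoningDataset(dataset=data, n_classes=10, p=0.2, seed=42)
    for i in poisoned.samples_to_perturb:
        assert poisoned[i][1] != data[i][1]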
diff --git a/tests/utils/datasets/__init__.py b/tests/utils/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utils/datasets/transformed/__init__.py b/tests/utils/datasets/transformed/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utils/test_grouped_label_dataset.py b/tests/utils/datasets/transformed/test_label_grouping.py
similarity index 69%
rename from tests/utils/test_grouped_label_dataset.py
rename to tests/utils/datasets/transformed/test_label_grouping.py
index 94b23ff2..1a4d39e1 100644
--- a/tests/utils/test_grouped_label_dataset.py
+++ b/tests/utils/datasets/transformed/test_label_grouping.py
@@ -1,6 +1,6 @@
 import pytest
 
-from src.utils.datasets.group_label_dataset import GroupLabelDataset
+from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset
 
 
 @pytest.mark.utils
@@ -17,7 +17,7 @@
         ),
     ],
 )
-def test_identical_subclass_metrics(
+def test_label_grouping(
     dataset,
     n_classes,
     n_groups,
@@ -28,7 +28,7 @@
 ):
     dataset = request.getfixturevalue(dataset)
 
-    grouped_dataset = GroupLabelDataset(
+    grouped_dataset = LabelGroupingDataset(
         dataset=dataset,
         n_classes=n_classes,
         n_groups=n_groups,
@@ -40,7 +40,7 @@
 
     for i in range(len(grouped_dataset)):
         x, g = grouped_dataset[i]
-        y = grouped_dataset.get_subclass_label(i)
-        assertions.append((g in range(n_groups)) & (g == grouped_dataset.class_to_group[y]))
+        y = grouped_dataset._get_original_label(i)
+        assertions.append((i not in grouped_dataset.samples_to_perturb) or (g == grouped_dataset.class_to_group[y]))
 
     assert all(assertions)