From 70f91ebd3caa837e975feabd12935c5b37097fc7 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Mon, 17 Jun 2024 11:42:31 +0200 Subject: [PATCH 1/5] update trainer base --- .gitignore | 7 +++++ src/downstream_tasks/__init__.py | 0 .../subclass_identification.py | 28 ++++++++++--------- src/utils/training/trainer.py | 27 +++++++++++------- .../test_subclass_identification.py | 3 +- tests/utils/test_training.py | 2 +- 6 files changed, 41 insertions(+), 26 deletions(-) create mode 100644 src/downstream_tasks/__init__.py diff --git a/.gitignore b/.gitignore index 476774e1..b10f0f22 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,10 @@ coverage.xml .pytest_cache/ cover/ /scratch.py + +# Lightning +lightning_logs/ +checkpoints/ + +# data_attribution_evaluation +cache/ diff --git a/src/downstream_tasks/__init__.py b/src/downstream_tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/downstream_tasks/subclass_identification.py b/src/downstream_tasks/subclass_identification.py index 2d9a3ac7..2a7244c8 100644 --- a/src/downstream_tasks/subclass_identification.py +++ b/src/downstream_tasks/subclass_identification.py @@ -1,6 +1,7 @@ import os from typing import Callable, Dict, Optional, Union +import lightning as L import torch from metrics.localization.identical_class import IdenticalClass @@ -13,26 +14,27 @@ class SubclassIdentification: - def __init__(self, device: str = "cpu", *args, **kwargs): - self.device = device - self.trainer: Optional[BaseTrainer] = None - - def init_trainer_from_lightning_module(self, pl_module): - trainer = Trainer() - trainer.from_lightning_module(pl_module) - self.trainer = trainer - - def init_trainer_from_train_arguments( + def __init__( self, model: torch.nn.Module, optimizer: Callable, lr: float, criterion: torch.nn.modules.loss._Loss, optimizer_kwargs: Optional[dict] = None, + device: str = "cpu", + *args, + **kwargs, ): - trainer = Trainer() - trainer.from_train_arguments(model, optimizer, lr, criterion, optimizer_kwargs) - self.trainer = trainer + self.device = device + self.trainer: Optional[BaseTrainer] = Trainer.from_arguments(model, optimizer, lr, criterion, optimizer_kwargs) + + @classmethod + def from_lightning_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs): + obj = cls.__new__(cls) + super(SubclassIdentification, obj).__init__() + obj.device = device + obj.trainer = Trainer.from_lightning_module(model, pl_module) + return obj def evaluate( self, diff --git a/src/utils/training/trainer.py b/src/utils/training/trainer.py index b13c51ad..e56b289a 100644 --- a/src/utils/training/trainer.py +++ b/src/utils/training/trainer.py @@ -23,30 +23,37 @@ def fit( class Trainer(BaseTrainer): def __init__(self): - pass + self.model: Optional[torch.nn.Module] = None + self.module: Optional[L.LightningModule] = None - def from_train_arguments( - self, + @classmethod + def from_arguments( + cls, model: torch.nn.Module, optimizer: Callable, lr: float, criterion: torch.nn.modules.loss._Loss, optimizer_kwargs: Optional[dict] = None, ): - self.model = model + obj = cls.__new__(cls) + super(Trainer, obj).__init__() + obj.model = model if optimizer_kwargs is None: optimizer_kwargs = {} - self.module = BasicLightningModule(model, optimizer, lr, criterion, optimizer_kwargs) - return self + obj.module = BasicLightningModule(model, optimizer, lr, criterion, optimizer_kwargs) + return obj + @classmethod def from_lightning_module( - self, + cls, model: torch.nn.Module, pl_module: 
L.LightningModule,
     ):
-        self.model = model
-        self.module = pl_module
-        return self
+        obj = cls.__new__(cls)
+        super(Trainer, obj).__init__()
+        obj.model = model
+        obj.module = pl_module
+        return obj
 
     def fit(
         self,
diff --git a/tests/downstream_tasks/test_subclass_identification.py b/tests/downstream_tasks/test_subclass_identification.py
index 7994003d..7cbed652 100644
--- a/tests/downstream_tasks/test_subclass_identification.py
+++ b/tests/downstream_tasks/test_subclass_identification.py
@@ -50,8 +50,7 @@ def test_identical_subclass_metrics(
     test_labels = request.getfixturevalue(test_labels)
     dataset = request.getfixturevalue(dataset)
 
-    dst_eval = SubclassIdentification()
-    dst_eval.init_trainer_from_train_arguments(
+    dst_eval = SubclassIdentification(
         model=model,
         optimizer=optimizer,
         lr=lr,
diff --git a/tests/utils/test_training.py b/tests/utils/test_training.py
index 03723fa9..038a7570 100644
--- a/tests/utils/test_training.py
+++ b/tests/utils/test_training.py
@@ -43,7 +43,7 @@ def test_easy_trainer(
     criterion = request.getfixturevalue(criterion)
     old_model = copy.deepcopy(model)
     trainer = Trainer()
-    model = trainer.from_train_arguments(
+    model = trainer.from_arguments(
         model=model,
         optimizer=optimizer,
         lr=lr,

From 6dffcccf7eb79398a95dcb2ed6f885831a6525b0 Mon Sep 17 00:00:00 2001
From: Dilyara Bareeva
Date: Mon, 17 Jun 2024 15:39:51 +0200
Subject: [PATCH 2/5] fix self influence test

---
 tests/explainers/aggregators/test_self_influence.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/tests/explainers/aggregators/test_self_influence.py b/tests/explainers/aggregators/test_self_influence.py
index b43b9dae..0405693c 100644
--- a/tests/explainers/aggregators/test_self_influence.py
+++ b/tests/explainers/aggregators/test_self_influence.py
@@ -1,3 +1,5 @@
+import os
+import shutil
 from collections import OrderedDict
 
 import pytest
@@ -29,10 +31,14 @@ def test_self_influence_ranking(test_id, explain_kwargs, request):
     self_influence_rank = get_self_influence_ranking(
         model=model,
         model_id="0",
-        cache_dir="temp_captum",
+        cache_dir="./test_cache",
         training_data=rand_dataset,
         explain_fn=explain,
         explain_fn_kwargs=explain_kwargs,
     )
 
+    # remove cache directory if it exists
+    if os.path.exists("./test_cache"):
+        shutil.rmtree("./test_cache")
+
     assert torch.allclose(self_influence_rank, torch.linalg.norm(X, dim=-1).argsort())

From f73f19527a1db2655301ca42814a15114d24e036 Mon Sep 17 00:00:00 2001
From: Dilyara Bareeva
Date: Mon, 17 Jun 2024 16:27:38 +0200
Subject: [PATCH 3/5] add scheduler to training + after-test clean-up

---
 Makefile                                      |  3 ++-
 src/utils/training/base_pl_module.py          |  9 +++++++++
 src/utils/training/trainer.py                 | 12 +++++++++++-
 .../test_subclass_identification.py           |  9 ++++++++-
 tests/metrics/test_randomization_metrics.py   |  9 +++++++++
 tests/utils/test_explain_wrapper.py           | 16 ++++++++++------
 6 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/Makefile b/Makefile
index 56b633f1..56e45bd3 100644
--- a/Makefile
+++ b/Makefile
@@ -12,5 +12,6 @@ style:
 	find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
 	find . | grep -E ".pytest_cache" | xargs rm -rf
 	find . | grep -E ".mypy_cache" | xargs rm -rf
-	find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
+	find . | grep -E ".checkpoints" | xargs rm -rf
+	find . | grep -E ".lightning_logs" | xargs rm -rf
 	find . 
-name '*~' -exec rm -f {} + diff --git a/src/utils/training/base_pl_module.py b/src/utils/training/base_pl_module.py index 9fe52e7a..98be8d92 100644 --- a/src/utils/training/base_pl_module.py +++ b/src/utils/training/base_pl_module.py @@ -11,7 +11,9 @@ def __init__( optimizer: Callable, lr: float, criterion: torch.nn.modules.loss._Loss, + scheduler: Optional[Callable] = None, optimizer_kwargs: Optional[dict] = None, + scheduler_kwargs: Optional[dict] = None, *args, **kwargs, ): @@ -22,6 +24,8 @@ def __init__( self.lr = lr self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} self.criterion = criterion + self.scheduler = scheduler + self.scheduler_kwargs = scheduler_kwargs if scheduler_kwargs is not None else {} def forward(self, inputs): return self.model(inputs) @@ -44,4 +48,9 @@ def configure_optimizers(self): optimizer = self.optimizer(self.model.parameters(), lr=self.lr, **self.optimizer_kwargs) if not isinstance(optimizer, torch.optim.Optimizer): raise ValueError("optimizer must be an instance of torch.optim.Optimizer") + if self.scheduler is not None: + scheduler = self.scheduler(optimizer, **self.scheduler_kwargs) + if not isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler): + raise ValueError("scheduler must be an instance of torch.optim.lr_scheduler._LRScheduler") + return {"optimizer": optimizer, "lr_scheduler": scheduler} return optimizer diff --git a/src/utils/training/trainer.py b/src/utils/training/trainer.py index e56b289a..508e7271 100644 --- a/src/utils/training/trainer.py +++ b/src/utils/training/trainer.py @@ -33,14 +33,24 @@ def from_arguments( optimizer: Callable, lr: float, criterion: torch.nn.modules.loss._Loss, + scheduler: Optional[Callable] = None, optimizer_kwargs: Optional[dict] = None, + scheduler_kwargs: Optional[dict] = None, ): obj = cls.__new__(cls) super(Trainer, obj).__init__() obj.model = model if optimizer_kwargs is None: optimizer_kwargs = {} - obj.module = BasicLightningModule(model, optimizer, lr, criterion, optimizer_kwargs) + obj.module = BasicLightningModule( + model=model, + optimizer=optimizer, + lr=lr, + criterion=criterion, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + ) return obj @classmethod diff --git a/tests/downstream_tasks/test_subclass_identification.py b/tests/downstream_tasks/test_subclass_identification.py index 7cbed652..68a2484a 100644 --- a/tests/downstream_tasks/test_subclass_identification.py +++ b/tests/downstream_tasks/test_subclass_identification.py @@ -1,3 +1,6 @@ +import os +import shutil + import pytest from downstream_tasks.subclass_identification import SubclassIdentification @@ -65,11 +68,15 @@ def test_identical_subclass_metrics( explain_fn=explain, explain_kwargs=explain_kwargs, trainer_kwargs={"max_epochs": max_epochs}, - cache_dir="./cache", + cache_dir="./test_cache", model_id="default_model_id", run_id="default_subclass_identification", seed=seed, batch_size=batch_size, device="cpu", ) + + # remove cache directory if it exists + if os.path.exists("./test_cache"): + shutil.rmtree("./test_cache") assert score == expected_score diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index b9909e9a..5524771c 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -1,3 +1,6 @@ +import os +import shutil + import pytest import torch @@ -41,12 +44,18 @@ def test_randomization_metric( train_dataset=dataset, explain_fn=explain, 
explain_fn_kwargs={**explain_kwargs, "layer": "fc_2"},
+        cache_dir="./test_cache",
         correlation_fn="spearman",
         seed=42,
         device="cpu",
     )
     metric.update(test_data, tda)
     out = metric.compute()
 
+    # remove cache directory if it exists
+    if os.path.exists("./test_cache"):
+        shutil.rmtree("./test_cache")
+
     assert (out.item() >= -1.0) and (out.item() <= 1.0), "Test failed."
     assert isinstance(out, torch.Tensor), "Output is not a tensor."
diff --git a/tests/utils/test_explain_wrapper.py b/tests/utils/test_explain_wrapper.py
index cbc69e92..7c3670b5 100644
--- a/tests/utils/test_explain_wrapper.py
+++ b/tests/utils/test_explain_wrapper.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 
 import pytest
 import torch
@@ -36,12 +37,15 @@ def test_explain(
     test_tensor = request.getfixturevalue(test_tensor)
     explanations_exp = request.getfixturevalue(explanations)
     explanations = explain(
-        model,
-        test_id,
-        os.path.join("./cache", "test_id"),
-        method,
-        dataset,
-        test_tensor,
+        model=model,
+        model_id=test_id,
+        cache_dir="./test_cache",
+        method=method,
+        train_dataset=dataset,
+        test_tensor=test_tensor,
         **method_kwargs,
     )
+    # remove cache directory if it exists
+    if os.path.exists("./test_cache"):
+        shutil.rmtree("./test_cache")
     assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"

From 868c51aa847fdb37c97f1d4f73cc1eaf3ea8639c Mon Sep 17 00:00:00 2001
From: Dilyara Bareeva
Date: Mon, 17 Jun 2024 18:46:55 +0200
Subject: [PATCH 4/5] usage testing for cifar10 resnet18

---
 .../randomization/model_randomization.py |   8 +-
 src/utils/explain_wrapper.py             |   6 +-
 tutorials/usage_testing.py               | 149 ++++++++++++++++++
 3 files changed, 158 insertions(+), 5 deletions(-)
 create mode 100644 tutorials/usage_testing.py

diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py
index 51e3f5e1..25ac7db9 100644
--- a/src/metrics/randomization/model_randomization.py
+++ b/src/metrics/randomization/model_randomization.py
@@ -3,10 +3,10 @@
 
 import torch
 
-from metrics.base import Metric
-from utils.common import _get_parent_module_from_name, make_func
-from utils.explain_wrapper import ExplainFunc
-from utils.functions.correlations import (
+from src.metrics.base import Metric
+from src.utils.common import _get_parent_module_from_name, make_func
+from src.utils.explain_wrapper import ExplainFunc
+from src.utils.functions.correlations import (
     CorrelationFnLiterals,
     correlation_functions,
 )
diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py
index bb82460d..81205588 100644
--- a/src/utils/explain_wrapper.py
+++ b/src/utils/explain_wrapper.py
@@ -62,7 +62,11 @@ def explain(
         similarity_direction=sim_direction,
         batch_size=batch_size,
     )
-    topk_idx, topk_val = sim_influence.influence(test_tensor, len(train_dataset))[layer]
+    topk_idx, topk_val = sim_influence.influence(
+        inputs=test_tensor,
+        top_k=len(train_dataset),
+        # load_src_from_disk=False
+    )[layer]
     tda = torch.gather(topk_val, 1, topk_idx)
     return tda
 
diff --git a/tutorials/usage_testing.py b/tutorials/usage_testing.py
new file mode 100644
index 00000000..cafdd0df
--- /dev/null
+++ b/tutorials/usage_testing.py
@@ -0,0 +1,149 @@
+"Large chunks of code borrowed from https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb"
+
+import os
+from multiprocessing import freeze_support
+
+import matplotlib.pyplot as plt
+import requests
+import torch
+import torchvision
+
+# from torch import nn
+# from torch import optim
+from 
torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.models import resnet18
+from torchvision.utils import make_grid
+from tqdm import tqdm
+
+from src.metrics.randomization.model_randomization import (
+    ModelRandomizationMetric,
+)
+from src.utils.explain_wrapper import explain
+from src.utils.functions.similarities import cosine_similarity
+
+DEVICE = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
+print("Running on device:", DEVICE.upper())
+
+# manual random seed is used for dataset partitioning
+# to ensure reproducible results across runs
+RNG = torch.Generator().manual_seed(42)
+
+# ++++++++++++++++++++++++++++++++++++++++++
+# Download dataset and pre-trained model
+# ++++++++++++++++++++++++++++++++++++++++++
+
+
+def main():
+    # download and pre-process CIFAR10
+    normalize = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+        ]
+    )
+
+    train_set = torchvision.datasets.CIFAR10(root="./tutorials/data", train=True, download=True, transform=normalize)
+    train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
+
+    # we split the held-out data into test and validation sets
+    held_out = torchvision.datasets.CIFAR10(root="./tutorials/data", train=False, download=True, transform=normalize)
+    test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
+    test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)
+    # val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
+
+    # download pre-trained weights
+    local_path = "./tutorials/model_weights_resnet18_cifar10.pth"
+    if not os.path.exists(local_path):
+        response = requests.get("https://storage.googleapis.com/unlearning-challenge/weights_resnet18_cifar10.pth")
+        open(local_path, "wb").write(response.content)
+
+    weights_pretrained = torch.load(local_path, map_location=DEVICE)
+
+    # load model with pre-trained weights
+    model = resnet18(weights=None, num_classes=10)
+    model.load_state_dict(weights_pretrained)
+    model.to(DEVICE)
+    model.eval()
+
+    # a temporary data loader without normalization, just to show the images
+    tmp_dl = DataLoader(
+        torchvision.datasets.CIFAR10(root="./tutorials/data", train=True, download=True, transform=transforms.ToTensor()),
+        batch_size=16 * 5,
+        shuffle=False,
+    )
+    images, labels = next(iter(tmp_dl))
+
+    fig, ax = plt.subplots(figsize=(12, 6))
+    plt.title("Sample images from CIFAR10 dataset")
+    ax.set_xticks([])
+    ax.set_yticks([])
+    ax.imshow(make_grid(images, nrow=16).permute(1, 2, 0))
+    plt.show()
+
+    def accuracy(net, loader):
+        """Return accuracy on a dataset given by the data loader."""
+        correct = 0
+        total = 0
+        for inputs, targets in loader:
+            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
+            outputs = net(inputs)
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+        return correct / total
+
+    print(f"Train set accuracy: {100.0 * accuracy(model, train_loader):0.1f}%")
+    print(f"Test set accuracy: {100.0 * accuracy(model, test_loader):0.1f}%")
+
+    # ++++++++++++++++++++++++++++++++++++++++++
+    # Training configuration
+    # ++++++++++++++++++++++++++++++++++++++++++
+    """
+    max_epochs = 5
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD
+    optimizer_kwargs = {"lr": 0.1, "momentum": 0.9, "weight_decay": 5e-4}
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
+    scheduler_kwargs = {"T_max": max_epochs}
+    """
+    # 
++++++++++++++++++++++++++++++++++++++++++
+    # Computing metrics while generating explanations
+    # ++++++++++++++++++++++++++++++++++++++++++
+
+    metric = ModelRandomizationMetric(
+        model=model,
+        train_dataset=train_set,
+        explain_fn=explain,
+        explain_fn_kwargs={"method": "SimilarityInfluence", "layer": "avgpool"},
+        model_id="default_model_id",
+        cache_dir="./cache",
+        correlation_fn="spearman",
+        seed=42,
+        device="cpu",
+    )
+
+    # iterate over test set and feed tensor batches first to explain, then to metric
+    for i, (data, target) in enumerate(tqdm(test_loader)):
+        data, target = data.to(DEVICE), target.to(DEVICE)
+        tda = explain(
+            model=model,
+            model_id="default_model_id",
+            cache_dir="./cache",
+            method="SimilarityInfluence",
+            train_dataset=train_set,
+            test_tensor=data,
+            layer="avgpool",
+            similarity_metric=cosine_similarity,
+            similarity_direction="max",
+            batch_size=1,
+        )
+        metric.update(data, tda)
+
+    print("Model randomization metric output:", metric.compute().item())
+    print(f"Test set accuracy: {100.0 * accuracy(model, test_loader):0.1f}%")
+
+
+if __name__ == "__main__":
+    freeze_support()
+    main()

From 4ec56f911eca85e4f849c497a165fd56478bd69c Mon Sep 17 00:00:00 2001
From: dilyabareeva
Date: Fri, 21 Jun 2024 18:21:19 +0200
Subject: [PATCH 5/5] fix the mypy errors

---
 Makefile                                      |  3 ++
 pyproject.toml                                |  3 +-
 src/downstream_tasks/base.py                  |  2 +-
 .../subclass_identification.py                | 28 +++++++-----
 src/explainers/aggregators.py                 |  6 ++-
 src/explainers/base.py                        | 19 ++++++--
 src/explainers/functional.py                  |  6 +--
 src/explainers/wrappers/captum_influence.py   |  4 +-
 src/metrics/base.py                           | 20 +++++++--
 src/metrics/localization/identical_class.py   | 12 ++---
 .../randomization/model_randomization.py      | 14 +++---
 src/utils/cache.py                            | 23 +++++-----
 src/utils/datasets/activation_dataset.py      |  2 +-
 src/utils/datasets/corrupt_label_dataset.py   |  3 ++
 src/utils/datasets/group_label_dataset.py     | 11 +++--
 .../datasets/sample_transform_dataset.py      |  5 ++-
 src/utils/datasets/utils.py                   | 44 -------------------
 src/utils/explanations.py                     | 13 +++---
 src/utils/training/trainer.py                 | 13 ++++--
 tests/explainers/test_aggregators.py          |  4 +-
 tox.ini                                       |  2 +-
 21 files changed, 122 insertions(+), 115 deletions(-)
 delete mode 100644 src/utils/datasets/utils.py

diff --git a/Makefile b/Makefile
index dbbb18a4..1688b5f0 100644
--- a/Makefile
+++ b/Makefile
@@ -13,5 +13,8 @@ style:
 	find . | grep -E ".pytest_cache" | xargs rm -rf
 	find . | grep -E ".mypy_cache" | xargs rm -rf
 	find . | grep -E ".checkpoints" | xargs rm -rf
+	find . | grep -E ".egg-info" | xargs rm -rf
+	find . | grep -E ".build" | xargs rm -rf
+	find . | grep -E ".htmlcov" | xargs rm -rf
 	find . | grep -E ".lightning_logs" | xargs rm -rf
 	find . 
-name '*~' -exec rm -f {} + diff --git a/pyproject.toml b/pyproject.toml index f7c14b63..c20728c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ include_trailing_comma = true python_version = "3.11" warn_return_any = false warn_unused_configs = true -ignore_errors = true # TODO: change this +check_untyped_defs = true +#ignore_errors = true # TODO: change this # Black formatting [tool.black] diff --git a/src/downstream_tasks/base.py b/src/downstream_tasks/base.py index bba84cd8..5b07718a 100644 --- a/src/downstream_tasks/base.py +++ b/src/downstream_tasks/base.py @@ -23,7 +23,7 @@ def __init__(self, device: str = "cpu", *args, **kwargs): def evaluate( self, model: torch.nn.Module, - dataset: torch.utils.data.dataset, + dataset: torch.utils.data.Dataset, *args, **kwargs, ): diff --git a/src/downstream_tasks/subclass_identification.py b/src/downstream_tasks/subclass_identification.py index cf3e51eb..fe1cc074 100644 --- a/src/downstream_tasks/subclass_identification.py +++ b/src/downstream_tasks/subclass_identification.py @@ -4,14 +4,14 @@ import lightning as L import torch -from explainers.functional import ExplainFunc -from explainers.wrappers.captum_influence import captum_similarity_explain -from metrics.localization.identical_class import IdenticalClass -from utils.datasets.group_label_dataset import ( +from src.explainers.functional import ExplainFunc +from src.explainers.wrappers.captum_influence import captum_similarity_explain +from src.metrics.localization.identical_class import IdenticalClass +from src.utils.datasets.group_label_dataset import ( ClassToGroupLiterals, GroupLabelDataset, ) -from utils.training.trainer import BaseTrainer, Trainer +from src.utils.training.trainer import BaseTrainer, Trainer class SubclassIdentification: @@ -27,10 +27,12 @@ def __init__( **kwargs, ): self.device = device - self.trainer: Optional[BaseTrainer] = Trainer.from_arguments(model, optimizer, lr, criterion, optimizer_kwargs) + self.trainer: Optional[BaseTrainer] = Trainer.from_arguments( + model=model, optimizer=optimizer, lr=lr, criterion=criterion, optimizer_kwargs=optimizer_kwargs + ) @classmethod - def from_lightning_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs): + def from_pl_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs): obj = cls.__new__(cls) super(SubclassIdentification, obj).__init__() obj.device = device @@ -39,8 +41,8 @@ def from_lightning_module(cls, model: torch.nn.Module, pl_module: L.LightningMod def evaluate( self, - train_dataset: torch.utils.data.dataset, - val_dataset: Optional[torch.utils.data.dataset] = None, + train_dataset: torch.utils.data.Dataset, + val_dataset: Optional[torch.utils.data.Dataset] = None, n_classes: int = 10, n_groups: int = 2, class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random", @@ -83,13 +85,15 @@ def evaluate( class_to_group=grouped_dataset.class_to_group, seed=seed, ) - val_loader = torch.utils.data.DataLoader(grouped_val_dataset, batch_size=batch_size) + val_loader: Optional[torch.utils.data.DataLoader] = torch.utils.data.DataLoader( + grouped_val_dataset, batch_size=batch_size + ) else: val_loader = None model = self.trainer.fit( - grouped_train_loader, - val_loader, + train_loader=grouped_train_loader, + val_loader=val_loader, trainer_kwargs=trainer_kwargs, ) metric = IdenticalClass(model=model, train_dataset=train_dataset, device="cpu") diff --git a/src/explainers/aggregators.py 
b/src/explainers/aggregators.py
index 9d1a1a79..de87331b 100644
--- a/src/explainers/aggregators.py
+++ b/src/explainers/aggregators.py
@@ -23,7 +23,7 @@ def reset(self, *args, **kwargs):
         """
         Used to reset the aggregator state.
         """
-        self.scores: torch.Tensor = None
+        self.scores = None
 
     def load_state_dict(self, state_dict: dict, *args, **kwargs):
         """
@@ -38,7 +38,9 @@ def state_dict(self, *args, **kwargs):
         return {"scores": self.scores}
 
     def compute(self) -> torch.Tensor:
-        return self.scores.argsort()
+        if self.scores is None:
+            raise ValueError("No scores to aggregate.")
+        return self.scores
 
 
 class SumAggregator(BaseAggregator):
diff --git a/src/explainers/base.py b/src/explainers/base.py
index 35691951..88a4807e 100644
--- a/src/explainers/base.py
+++ b/src/explainers/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sized, Union
 
 import torch
@@ -28,8 +28,19 @@ def __init__(
     def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs: Any):
         raise NotImplementedError
 
+    @property
+    def dataset_length(self) -> int:
+        """
+        torch.utils.data.Dataset does not necessarily implement __len__.
+        :return: the number of samples in the training dataset
+        """
+        if isinstance(self.train_dataset, Sized):
+            return len(self.train_dataset)
+        dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=1)
+        return len(dl)
+
     @cache_result
-    def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch.Tensor:
+    def self_influence(self, batch_size: int = 32, **kwargs: Any) -> torch.Tensor:
         """
         Base class implements computing self influences by explaining the train dataset one by one
 
@@ -39,10 +50,10 @@ def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch
         """
 
         # Pre-allocate memory for influences, because torch.cat is slow
-        influences = torch.empty((len(self.train_dataset),), device=self.device)
+        influences = torch.empty((self.dataset_length,), device=self.device)
         ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size)
 
-        for i, (x, y) in zip(range(0, len(self.train_dataset), batch_size), ldr):
+        for i, (x, y) in zip(range(0, self.dataset_length, batch_size), ldr):
             explanations = self.explain(test=x.to(self.device), **kwargs)
             influences[i : i + batch_size] = explanations.diag(diagonal=i)
 
diff --git a/src/explainers/functional.py b/src/explainers/functional.py
index 8e7b33b8..5b2cb0ae 100644
--- a/src/explainers/functional.py
+++ b/src/explainers/functional.py
@@ -10,10 +10,10 @@ def __call__(
         model_id: str,
         cache_dir: Optional[str],
         test_tensor: torch.Tensor,
-        explanation_targets: Optional[Union[List[int], torch.Tensor]],
         train_dataset: torch.utils.data.Dataset,
-        explain_kwargs: Dict,
-        init_kwargs: Dict,
         device: Union[str, torch.device],
+        explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
+        init_kwargs: Optional[Dict] = None,
+        explain_kwargs: Optional[Dict] = None,
     ) -> torch.Tensor:
         pass
diff --git a/src/explainers/wrappers/captum_influence.py b/src/explainers/wrappers/captum_influence.py
index 723e0026..58b41e0b 100644
--- a/src/explainers/wrappers/captum_influence.py
+++ b/src/explainers/wrappers/captum_influence.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-from captum.influence import SimilarityInfluence
+from captum.influence import SimilarityInfluence  # type: ignore
 
 from src.explainers.base import BaseExplainer
 from src.explainers.utils import (
@@ -117,7 
+117,7 @@ def layer(self, layers: Any):
 
     def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs: Any):
         # We might want to pass the top_k as an argument in some scenarios
-        top_k = kwargs.get("top_k", len(self.train_dataset))
+        top_k = kwargs.get("top_k", self.dataset_length)
         topk_idx, topk_val = super().explain(test=test, top_k=top_k, **kwargs)[self.layer]
         inverted_idx = topk_idx.argsort()
diff --git a/src/metrics/base.py b/src/metrics/base.py
index 909d3905..8e1f5b1d 100644
--- a/src/metrics/base.py
+++ b/src/metrics/base.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Sized
 
 import torch
@@ -7,14 +8,14 @@ class Metric(ABC):
     def __init__(
         self,
         model: torch.nn.Module,
-        train_dataset: torch.utils.data.dataset,
+        train_dataset: torch.utils.data.Dataset,
         device: str = "cpu",
         *args,
         **kwargs,
     ):
-        self.model = model.to(device)
-        self.train_dataset = train_dataset
-        self.device = device
+        self.model: torch.nn.Module = model.to(device)
+        self.train_dataset: torch.utils.data.Dataset = train_dataset
+        self.device: str = device
 
     @abstractmethod
     def update(
@@ -59,3 +60,14 @@ def state_dict(self, *args, **kwargs):
         """
 
         raise NotImplementedError
+
+    @property
+    def dataset_length(self) -> int:
+        """
+        torch.utils.data.Dataset does not necessarily implement __len__.
+        :return: the number of samples in the training dataset
+        """
+        if isinstance(self.train_dataset, Sized):
+            return len(self.train_dataset)
+        dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=1)
+        return len(dl)
diff --git a/src/metrics/localization/identical_class.py b/src/metrics/localization/identical_class.py
index 5ae98f52..ce0a6167 100644
--- a/src/metrics/localization/identical_class.py
+++ b/src/metrics/localization/identical_class.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 
 from src.metrics.base import Metric
@@ -12,8 +14,8 @@ def __init__(
         *args,
         **kwargs,
     ):
-        super().__init__(model=model, train_dataset=train_dataset, device=device, *args, **kwargs)
-        self.scores = []
+        super().__init__(model=model, train_dataset=train_dataset, device=device)
+        self.scores: List[torch.Tensor] = []
 
     def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
         """
@@ -65,11 +67,11 @@ def __init__(
         *args,
         **kwargs,
     ):
-        assert len(subclass_labels) == len(train_dataset), (
+        super().__init__(model, train_dataset, device, *args, **kwargs)
+        assert len(subclass_labels) == self.dataset_length, (
             f"Number of subclass labels ({len(subclass_labels)}) "
-            f"does not match the number of train dataset samples ({len(train_dataset)})."
+            f"does not match the number of train dataset samples ({self.dataset_length})."
) - super().__init__(model, train_dataset, device, *args, **kwargs) self.subclass_labels = subclass_labels def update(self, test_subclasses: torch.Tensor, explanations: torch.Tensor): diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index e8070f8e..49ab460f 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -1,5 +1,5 @@ import copy -from typing import Callable, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch @@ -22,8 +22,8 @@ def __init__( model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, explain_fn: ExplainFunc, - explain_init_kwargs: Optional[dict] = {}, - explain_fn_kwargs: Optional[dict] = {}, + explain_init_kwargs: Optional[dict] = None, + explain_fn_kwargs: Optional[dict] = None, correlation_fn: Union[Callable, CorrelationFnLiterals] = "spearman", seed: int = 42, model_id: str = "0", @@ -39,8 +39,8 @@ def __init__( ) self.model = model self.train_dataset = train_dataset - self.explain_fn_kwargs = explain_fn_kwargs - self.explain_init_kwargs = explain_init_kwargs + self.explain_fn_kwargs = explain_fn_kwargs or {} + self.explain_init_kwargs = explain_init_kwargs or {} self.seed = seed self.model_id = model_id self.cache_dir = cache_dir @@ -64,11 +64,11 @@ def __init__( train_dataset=self.train_dataset, ) - self.results = {"scores": []} + self.results: Dict[str, List] = {"scores": []} # TODO: create a validation utility function if isinstance(correlation_fn, str) and correlation_fn in correlation_functions: - self.corr_measure = correlation_functions.get(correlation_fn) + self.corr_measure = correlation_functions[correlation_fn] elif callable(correlation_fn): self.corr_measure = correlation_fn else: diff --git a/src/utils/cache.py b/src/utils/cache.py index 8a7a9d6d..3054c8cf 100644 --- a/src/utils/cache.py +++ b/src/utils/cache.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional, Tuple, Union import torch -from captum.attr import LayerActivation +from captum.attr import LayerActivation # type: ignore from torch import Tensor from torch.utils.data import DataLoader @@ -22,15 +22,15 @@ def __init__(self): pass @staticmethod - def save(**kwargs) -> None: + def save(*args, **kwargs) -> None: raise NotImplementedError @staticmethod - def load(**kwargs) -> Any: + def load(*args, **kwargs) -> Any: raise NotImplementedError @staticmethod - def exists(**kwargs) -> bool: + def exists(*args, **kwargs) -> bool: raise NotImplementedError @@ -49,7 +49,7 @@ def load(path: str, file_id: str, device: str = "cpu") -> Tensor: return torch.load(file_path, map_location=device) @staticmethod - def exists(path: str, file_id: str, num_id: int) -> bool: + def exists(path: str, file_id: str) -> bool: file_path = os.path.join(path, file_id) return os.path.isfile(file_path) @@ -98,7 +98,7 @@ def __init__(self) -> None: @staticmethod def exists( path: str, - layer: Optional[str] = None, + layer: str, num_id: Optional[Union[str, int]] = None, ) -> bool: av_dir = os.path.join(path, layer) @@ -108,7 +108,7 @@ def exists( @staticmethod def save( path: str, - layers: List[str], + layers: str | List[str], act_tensors: List[Tensor], labels: Tensor, num_id: Union[str, int], @@ -128,8 +128,9 @@ def save( @staticmethod def load( path: str, - layer: Optional[str] = None, + layer: str, device: str = "cpu", + **kwargs, ) -> ActivationDataset: layer_dir = os.path.join(path, layer) @@ -145,15 +146,15 @@ def _manage_loading_layers( layers: 
Union[str, List[str]], load_from_disk: bool = True, num_id: Optional[Union[str, int]] = None, - ) -> List[str]: - unsaved_layers = [] + ) -> str | List[str]: + unsaved_layers: List[str] = [] if load_from_disk: for layer in layers: if not ActivationsCache.exists(path, layer, num_id): unsaved_layers.append(layer) else: - unsaved_layers = layers + unsaved_layers = [layers] if isinstance(layers, str) else layers warnings.warn( "Overwriting activations: load_from_disk is set to False. Removing all " f"activations matching specified parameters {{path: {path}, " diff --git a/src/utils/datasets/activation_dataset.py b/src/utils/datasets/activation_dataset.py index 7f5434a6..9d48f1f5 100644 --- a/src/utils/datasets/activation_dataset.py +++ b/src/utils/datasets/activation_dataset.py @@ -33,7 +33,7 @@ def samples_and_labels(self) -> Tuple[Tensor, Tensor]: samples = [] labels = [] - for sample, label in self: + for sample, label in self: # type: ignore samples.append(sample) labels.append(label) diff --git a/src/utils/datasets/corrupt_label_dataset.py b/src/utils/datasets/corrupt_label_dataset.py index e8cad1ba..c53b7f9b 100644 --- a/src/utils/datasets/corrupt_label_dataset.py +++ b/src/utils/datasets/corrupt_label_dataset.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# type: ignore + import random import torch diff --git a/src/utils/datasets/group_label_dataset.py b/src/utils/datasets/group_label_dataset.py index f605e5c4..d29fef9c 100644 --- a/src/utils/datasets/group_label_dataset.py +++ b/src/utils/datasets/group_label_dataset.py @@ -1,5 +1,5 @@ import random -from typing import Dict, Literal, Optional, Union +from typing import Dict, Literal, Optional, Sized, Union import torch from torch.utils.data import Dataset @@ -7,7 +7,7 @@ ClassToGroupLiterals = Literal["random"] -class GroupLabelDataset: +class GroupLabelDataset(Dataset): def __init__( self, dataset: Dataset, @@ -15,7 +15,7 @@ def __init__( n_groups: int = 2, class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random", seed: Optional[int] = 27, - device: int = "cpu", + device: str = "cpu", ): self.dataset = dataset self.n_classes = n_classes @@ -47,4 +47,7 @@ def get_subclass_label(self, index): return y def __len__(self): - return len(self.dataset) + if isinstance(self.dataset, Sized): + return len(self.dataset) + dl = torch.utils.data.DataLoader(self.dataset, batch_size=1) + return len(dl) diff --git a/src/utils/datasets/sample_transform_dataset.py b/src/utils/datasets/sample_transform_dataset.py index 7f9cc226..ecc642e3 100644 --- a/src/utils/datasets/sample_transform_dataset.py +++ b/src/utils/datasets/sample_transform_dataset.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# type: ignore + from typing import Callable, List, Optional, Union import torch @@ -37,7 +40,7 @@ def __init__( self.mark_indices = IC.load(path="./datasets", file_id=f"{dataset_id}_mark_ids") else: self.mark_indices = self.get_mark_sample_ids() - IC.save(path=cache_path, file_id=f"{dataset_id}_mark_ids") + IC.save(path=cache_path, file_id=f"{dataset_id}_mark_ids", indices=self.mark_indices) def __len__(self): return len(self.dataset) diff --git a/src/utils/datasets/utils.py b/src/utils/datasets/utils.py deleted file mode 100644 index 4043adc4..00000000 --- a/src/utils/datasets/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -from torchvision.datasets import CIFAR10, MNIST, FashionMNIST - -from src.utils.datasets.corrupt_label_dataset import CorruptLabelDataset -from src.utils.datasets.group_label_dataset import GroupLabelDataset -from 
src.utils.datasets.reduced_label_dataset import ReduceLabelDataset -from src.utils.datasets.sample_transform_dataset import MarkDataset - - -def load_datasets(dataset_name, dataset_type, **kwparams): - ds_dict = {"MNIST": MNIST, "CIFAR": CIFAR10, "FashionMNIST": FashionMNIST} - if "only_train" not in kwparams.keys(): - only_train = False - else: - only_train = kwparams["only_train"] - data_root = kwparams["data_root"] - class_groups = kwparams["class_groups"] - validation_size = kwparams["validation_size"] - set = kwparams["image_set"] - - if dataset_name in ds_dict.keys(): - dscls = ds_dict[dataset_name] - ds = dscls(root=data_root, split="train", validation_size=validation_size) - evalds = dscls(root=data_root, split=set, validation_size=validation_size) - else: - raise NameError(f"Unresolved dataset name : {dataset_name}.") - if dataset_type == "group": - ds = GroupLabelDataset(ds, class_groups=class_groups) - evalds = GroupLabelDataset(evalds, class_groups=class_groups) - elif dataset_type == "corrupt": - ds = CorruptLabelDataset(ds) - evalds = CorruptLabelDataset(evalds) - elif dataset_type == "mark": - ds = MarkDataset(ds, only_train=only_train) - evalds = MarkDataset(evalds, only_train=only_train) - # assert ds is not None and evalds is not None - return ds, evalds - - -def load_datasets_reduced(dataset_name, dataset_type, kwparams): - ds, evalds = load_datasets(dataset_name, dataset_type, **kwparams) - if dataset_type in ["group", "corrupt"]: - ds = ReduceLabelDataset(ds) - evalds = ReduceLabelDataset(evalds) - return ds, evalds diff --git a/src/utils/explanations.py b/src/utils/explanations.py index 90d2776a..cb11fde6 100644 --- a/src/utils/explanations.py +++ b/src/utils/explanations.py @@ -1,6 +1,5 @@ import glob import os -from typing import Optional, Tuple, Union import torch @@ -21,10 +20,10 @@ def __init__( """ pass - def __getitem__(self, index: Union[int, slice]) -> torch.Tensor: + def __getitem__(self, index: int) -> torch.Tensor: raise NotImplementedError - def __setitem__(self, index: Union[int, slice], val: Tuple[torch.Tensor, torch.Tensor]): + def __setitem__(self, index: int, val: torch.Tensor): raise NotImplementedError def __len__(self) -> int: @@ -35,7 +34,7 @@ class TensorExplanations(Explanations): def __init__( self, tensor: torch.Tensor, - batch_size: Optional[int] = 8, + batch_size: int = 8, device: str = "cpu", ): """ @@ -54,7 +53,7 @@ def __init__( # assert the number of explanation dimensions is 2 and insert extra dimension to emulate batching assert len(self.xpl.shape) == 2, "Explanations object has more than 2 dimensions." 
-    def __getitem__(self, idx: Union[int, slice]) -> torch.Tensor:
+    def __getitem__(self, idx: int) -> torch.Tensor:
         """
 
         :param idx:
@@ -62,7 +61,7 @@ def __getitem__(self, idx: Union[int, slice]) -> torch.Tensor:
         """
         return self.xpl[idx * self.batch_size : min((idx + 1) * self.batch_size, self.xpl.shape[0])]
 
-    def __setitem__(self, idx: Union[int, slice], val: Tuple[torch.Tensor, torch.Tensor]):
+    def __setitem__(self, idx: int, val: torch.Tensor):
         """
 
         :param idx:
@@ -116,7 +115,7 @@ def __getitem__(self, idx: int) -> torch.Tensor:
 
         return xpl
 
-    def __setitem__(self, idx: int, val: Tuple[torch.Tensor, torch.Tensor]):
+    def __setitem__(self, idx: int, val: torch.Tensor):
         """
 
         :param idx:
diff --git a/src/utils/training/trainer.py b/src/utils/training/trainer.py
index 508e7271..fbe6b763 100644
--- a/src/utils/training/trainer.py
+++ b/src/utils/training/trainer.py
@@ -5,7 +5,7 @@
 import lightning as L
 import torch
 
-from utils.training.base_pl_module import BasicLightningModule
+from src.utils.training.base_pl_module import BasicLightningModule
 
 
 class BaseTrainer(metaclass=abc.ABCMeta):
@@ -13,7 +13,7 @@ class BaseTrainer(metaclass=abc.ABCMeta):
     def fit(
         self,
         train_loader: torch.utils.data.dataloader.DataLoader,
-        val_loader: torch.utils.data.dataloader.DataLoader,
+        val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None,
         trainer_kwargs: Optional[dict] = None,
         *args,
         **kwargs,
@@ -68,11 +68,18 @@ def from_lightning_module(
     def fit(
         self,
         train_loader: torch.utils.data.dataloader.DataLoader,
-        val_loader: torch.utils.data.dataloader.DataLoader,
+        val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None,
         trainer_kwargs: Optional[dict] = None,
         *args,
         **kwargs,
    ):
+        if self.model is None:
+            raise ValueError("Model not initialized. Please initialize using from_arguments or from_lightning_module")
+        if self.module is None:
+            raise ValueError(
+                "Lightning module not initialized. Please initialize using from_arguments or from_lightning_module"
+            )
+
         if trainer_kwargs is None:
             trainer_kwargs = {}
         trainer = L.Trainer(**trainer_kwargs)
diff --git a/tests/explainers/test_aggregators.py b/tests/explainers/test_aggregators.py
index d166ca27..043f4bc8 100644
--- a/tests/explainers/test_aggregators.py
+++ b/tests/explainers/test_aggregators.py
@@ -18,7 +18,7 @@ def test_sum_aggregator(test_id, explanations, request):
     explanations = request.getfixturevalue(explanations)
     aggregator = SumAggregator()
     aggregator.update(explanations)
-    global_rank = aggregator.compute()
+    global_rank = aggregator.compute().argsort()
     assert torch.allclose(global_rank, explanations.sum(dim=0).argsort())
 
@@ -36,5 +36,5 @@ def test_abs_aggregator(test_id, explanations, request):
     explanations = request.getfixturevalue(explanations)
     aggregator = AbsSumAggregator()
     aggregator.update(explanations)
-    global_rank = aggregator.compute()
+    global_rank = aggregator.compute().argsort()
     assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort())
diff --git a/tox.ini b/tox.ini
index 355fcf94..b7363357 100644
--- a/tox.ini
+++ b/tox.ini
@@ -37,4 +37,4 @@ deps =
     {[testenv]deps}
     mypy==1.9.0
 commands =
-    python3 -m mypy src
+    python3 -m mypy src --check-untyped-defs
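
Usage sketch for the constructor API above (illustrative only, not part of the series): PATCH 1 replaces the
two-step init_trainer_* calls with classmethod constructors, and PATCH 3 threads optional scheduler arguments
through Trainer.from_arguments into BasicLightningModule.configure_optimizers. The toy model, data, and
hyperparameters below are placeholders:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from src.utils.training.trainer import Trainer

    # stand-in data and model for the sketch
    X, y = torch.randn(64, 10), torch.randint(0, 2, (64,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=16)
    model = torch.nn.Sequential(torch.nn.Linear(10, 2))

    # from_arguments builds the BasicLightningModule internally; scheduler and
    # scheduler_kwargs are optional and are validated in configure_optimizers
    trainer = Trainer.from_arguments(
        model=model,
        optimizer=torch.optim.SGD,
        lr=0.1,
        criterion=torch.nn.CrossEntropyLoss(),
        optimizer_kwargs={"momentum": 0.9},
        scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,
        scheduler_kwargs={"T_max": 5},
    )

    # per PATCH 5, fit raises a ValueError if the trainer was not built
    # via one of the classmethod constructors
    trained_model = trainer.fit(train_loader=train_loader, trainer_kwargs={"max_epochs": 5})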
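
A second sketch, for the aggregator contract that PATCH 5 changes: BaseAggregator.compute now returns the raw
aggregated scores rather than a ranking, so callers argsort themselves (as the updated tests do). The random
tensor below stands in for real attributions:

    import torch

    from src.explainers.aggregators import SumAggregator

    # placeholder (n_test x n_train) attribution matrix
    explanations = torch.randn(8, 100)

    aggregator = SumAggregator()
    aggregator.update(explanations)
    scores = aggregator.compute()  # per-training-sample aggregate; ValueError if nothing was accumulated
    ranking = scores.argsort()     # ranking is now the caller's responsibility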