From e9b05548f320308670c246d59252c92e8abe0f30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Sat, 15 Jun 2024 00:31:04 +0200 Subject: [PATCH 01/30] base class --- src/explainers/base.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/src/explainers/base.py b/src/explainers/base.py index c0f7665b..5a8e9f69 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -1,25 +1,49 @@ from abc import ABC, abstractmethod -from typing import Union +from typing import Optional, Union import torch class Explainer(ABC): - def __init__(self, model: torch.nn.Module, dataset: torch.data.utils.Dataset, device: Union[str, torch.device]): + def __init__( + self, + model: torch.nn.Module, + model_id: str, + train_dataset: torch.data.utils.Dataset, + device: Union[str, torch.device], + **kwargs, + ): self.model = model self.device = torch.device(device) if isinstance(device, str) else device - self.images = dataset + self.train_dataset = train_dataset self.samples = [] self.labels = [] + self._self_influences = None dev = torch.device(device) self.model.to(dev) @abstractmethod - def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor: - pass + def explain(self, test: torch.Tensor, **kwargs): + raise NotImplementedError - def train(self) -> None: - pass + @abstractmethod + def load_state_dict(self, path): + raise NotImplementedError + + def state_dict(self): + raise NotImplementedError + + @abstractmethod + def reset(self): + raise NotImplementedError - def save_coefs(self, dir_path: str) -> None: - pass + def self_influences(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: + if self._self_influences is None: + self._self_influences = torch.empty((len(self.train_dataset),), device=self.device) + ldr = torch.nn.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) + for i, (x, y) in iter(ldr): + upper_index = i * batch_size + x.shape[0] + explanations = self.explain(test=x, **kwargs) + explanations = explanations[:, i:upper_index] + self._self_influences[i:upper_index] = torch.diag(explanations) + return self._self_influences From 0ee06fe774a0f15fb43d6e382fdcb254926b14cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Sat, 15 Jun 2024 00:40:24 +0200 Subject: [PATCH 02/30] small changes --- src/explainers/feature_kernel_explainer.py | 70 ------------------ src/explainers/gradient_product_explainer.py | 75 -------------------- 2 files changed, 145 deletions(-) delete mode 100644 src/explainers/feature_kernel_explainer.py delete mode 100644 src/explainers/gradient_product_explainer.py diff --git a/src/explainers/feature_kernel_explainer.py b/src/explainers/feature_kernel_explainer.py deleted file mode 100644 index 16adb8c2..00000000 --- a/src/explainers/feature_kernel_explainer.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -from typing import Union - -import torch - -from src.explainers.base import Explainer -from src.utils.cache import ActivationsCache as AC - - -class FeatureKernelExplainer(Explainer): - def __init__( - self, - model: torch.nn.Module, - layer: str, - dataset: torch.data.utils.Dataset, - device: Union[str, torch.device], - file_path: str, - normalize: bool = True, - ): - """ - - :param model: - :param dataset: - :param device: - :param file_path: - :param normalize: - """ - super().__init__(model, dataset, device) - - layer = "features" # TODO: should be configurable - self.coefficients = None # the coefficients for each training datapoint x class - self.learned_weights = None - self.normalize = normalize - - self.samples, self.labels = self.generate_features(model, dataset, layer, file_path) - self.mean = self.samples.sum(0) / self.samples.shape[0] - self.stdvar = torch.sqrt(torch.sum((self.samples - self.mean) ** 2, dim=0) / self.samples.shape[0]) - self.normalized_samples = self.normalize_features(self.samples) if self.normalize else self.samples - - @staticmethod - def generate_features(model, dataset, layer, file_path): - dataloader = torch.utils.data.DataLoader(dataset, batch_size=32) - av_dataset = AC.generate_dataset_activations( - path=file_path, - model=model, - layers=[layer], - dataloader=dataloader, - load_from_disk=True, - return_activations=True, - )[0] - return av_dataset.samples_and_labels - - def normalize_features(self, features: torch.Tensor) -> torch.Tensor: - return (features - self.mean) / self.stdvar - - def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor): - assert self.coefficients is not None # TODO: shouldn't we calculate coefficients in here? - x = x.to(self.device) - f = self.model.features(x) # TODO: make it more flexible wrt. layer name - if self.normalize: - f = self.normalize_features(f) - crosscorr = torch.matmul(f, self.normalized_samples.T) - crosscorr = crosscorr[:, :, None] - xpl = self.coefficients * crosscorr - indices = explanation_targets[:, None, None].expand(-1, self.samples.shape[0], 1) - xpl = torch.gather(xpl, dim=-1, index=indices) - return torch.squeeze(xpl) - - def save_coefs(self, dir: str): - torch.save(self.coefficients, os.path.join(dir, f"{self.name}_coefs")) diff --git a/src/explainers/gradient_product_explainer.py b/src/explainers/gradient_product_explainer.py deleted file mode 100644 index 76ee4f52..00000000 --- a/src/explainers/gradient_product_explainer.py +++ /dev/null @@ -1,75 +0,0 @@ -import time -from typing import Union - -import torch - -from src.explainers.base import Explainer - - -class GradientProductExplainer(Explainer): - name = "GradientProductExplainer" - - def __init__( - self, model: torch.nn.Module, dataset: torch.utils.data.Dataset, device: Union[str, torch.device], loss=None - ): - super().__init__(model, dataset, device) - self.number_of_params = 0 - self.loss = loss - for p in list(self.model.sim_parameters()): - nn = 1 - for s in list(p.size()): - nn = nn * s - self.number_of_params += nn - # USE get_param_grad instead of grad_ds = GradientDataset(self.model, dataset) - self.dataset = dataset - - def get_param_grad(self, x: torch.Tensor, index: int = None): - x = x.to(self.device) - out = self.model(x[None, :, :]) - if index is None: - index = range(self.model.classifier.out_features) - else: - index = [index] - grads = torch.empty(len(index), self.number_of_params) - - for i, ind in enumerate(index): - assert ind > -1 and int(ind) == ind - self.model.zero_grad() - if self.loss is not None: - out_new = self.loss(out, torch.eye(out.shape[1], device=self.device)[None, ind]) - out_new.backward(retain_graph=True) - else: - out[0][ind].backward(retain_graph=True) - cumul = torch.empty(0, device=self.device) - for par in self.model.sim_parameters(): - grad = par.grad.flatten() - cumul = torch.cat((cumul, grad), 0) - grads[i] = cumul - - return torch.squeeze(grads) - - def explain(self, x, preds=None, targets=None): - assert not ((targets is None) and (self.loss is not None)) - xpl = torch.zeros((x.shape[0], len(self.dataset)), dtype=torch.float) - xpl = xpl.to(self.device) - t = time.time() - for j in range(len(self.dataset)): - tr_sample, y = self.dataset[j] - train_grad = self.get_param_grad(tr_sample, y) - train_grad = train_grad / torch.norm(train_grad) - train_grad.to(self.device) - for i in range(x.shape[0]): - if self.loss is None: - test_grad = self.get_param_grad(x[i], preds[i]) - else: - test_grad = self.get_param_grad(x[i], targets[i]) - test_grad.to(self.device) - xpl[i, j] = torch.matmul(train_grad, test_grad) - if j % 1000 == 0: - tdiff = time.time() - t - mins = int(tdiff / 60) - print( - f"{int(j / 1000)}/{int(len(self.dataset) / 1000)}k- 1000 images done in {mins} minutes {tdiff - 60 * mins}" - ) - t = time.time() - return xpl From 9767b307cd31722dc4ac1817d7ce59918e9c7bec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Sat, 15 Jun 2024 01:40:27 +0200 Subject: [PATCH 03/30] Captum interface and SimilarityInfluence --- src/explainers/aggregators/aggregators.py | 43 ----------------- src/explainers/base.py | 4 +- .../aggregators/test_aggregators.py | 47 ------------------- 3 files changed, 2 insertions(+), 92 deletions(-) delete mode 100644 src/explainers/aggregators/aggregators.py delete mode 100644 tests/explainers/aggregators/test_aggregators.py diff --git a/src/explainers/aggregators/aggregators.py b/src/explainers/aggregators/aggregators.py deleted file mode 100644 index 0d36b6b4..00000000 --- a/src/explainers/aggregators/aggregators.py +++ /dev/null @@ -1,43 +0,0 @@ -from abc import ABC, abstractmethod - -import torch - - -class ExplanationsAggregator(ABC): - def __init__(self, training_size: int, *args, **kwargs): - self.scores = torch.zeros(training_size) - - @abstractmethod - def update(self, explanations: torch.Tensor): - raise NotImplementedError - - def reset(self, *args, **kwargs): - """ - Used to reset the aggregator state. - """ - self.scores = torch.zeros_like(self.scores) - - def load_state_dict(self, state_dict: dict, *args, **kwargs): - """ - Used to load the aggregator state. - """ - self.scores = state_dict["scores"] - - def state_dict(self, *args, **kwargs): - """ - Used to return the metric state. - """ - return {"scores": self.scores} - - def compute(self) -> torch.Tensor: - return self.scores.argsort() - - -class SumAggregator(ExplanationsAggregator): - def update(self, explanations: torch.Tensor) -> torch.Tensor: - self.scores += explanations.sum(dim=0) - - -class AbsSumAggregator(ExplanationsAggregator): - def update(self, explanations: torch.Tensor) -> torch.Tensor: - self.scores += explanations.abs().sum(dim=0) diff --git a/src/explainers/base.py b/src/explainers/base.py index 5a8e9f69..d29123e0 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import List, Optional, Union import torch @@ -23,7 +23,7 @@ def __init__( self.model.to(dev) @abstractmethod - def explain(self, test: torch.Tensor, **kwargs): + def explain(self, test: torch.Tensor, targets: Union[List[int], torch.Tensor, None], **kwargs): raise NotImplementedError @abstractmethod diff --git a/tests/explainers/aggregators/test_aggregators.py b/tests/explainers/aggregators/test_aggregators.py deleted file mode 100644 index cfab2355..00000000 --- a/tests/explainers/aggregators/test_aggregators.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -from src.explainers.aggregators.aggregators import ( - AbsSumAggregator, - SumAggregator, -) - - -@pytest.mark.aggregators -@pytest.mark.parametrize( - "test_id, dataset, explanations", - [ - ( - "mnist", - "load_mnist_dataset", - "load_mnist_explanations_1", - ), - ], -) -def test_sum_aggregator(test_id, dataset, explanations, request): - dataset = request.getfixturevalue(dataset) - explanations = request.getfixturevalue(explanations) - aggregator = SumAggregator(training_size=len(dataset)) - aggregator.update(explanations) - global_rank = aggregator.compute() - assert torch.allclose(global_rank, explanations.sum(dim=0).argsort()) - - -@pytest.mark.aggregators -@pytest.mark.parametrize( - "test_id, dataset, explanations", - [ - ( - "mnist", - "load_mnist_dataset", - "load_mnist_explanations_1", - ), - ], -) -def test_abs_aggregator(test_id, dataset, explanations, request): - dataset = request.getfixturevalue(dataset) - explanations = request.getfixturevalue(explanations) - aggregator = AbsSumAggregator(training_size=len(dataset)) - aggregator.update(explanations) - global_rank = aggregator.compute() - assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort()) From 3e98e9fc48a23290495fa6537032d3b06d91da63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Sat, 15 Jun 2024 01:41:20 +0200 Subject: [PATCH 04/30] Last commit actually here --- src/explainers/aggregators.py | 43 +++++++++++++++++++++++++++ src/explainers/captum/base.py | 35 ++++++++++++++++++++++ src/explainers/captum/similarity.py | 28 ++++++++++++++++++ tests/explainers/test_aggregators.py | 44 ++++++++++++++++++++++++++++ 4 files changed, 150 insertions(+) create mode 100644 src/explainers/aggregators.py create mode 100644 src/explainers/captum/base.py create mode 100644 src/explainers/captum/similarity.py create mode 100644 tests/explainers/test_aggregators.py diff --git a/src/explainers/aggregators.py b/src/explainers/aggregators.py new file mode 100644 index 00000000..0d36b6b4 --- /dev/null +++ b/src/explainers/aggregators.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod + +import torch + + +class ExplanationsAggregator(ABC): + def __init__(self, training_size: int, *args, **kwargs): + self.scores = torch.zeros(training_size) + + @abstractmethod + def update(self, explanations: torch.Tensor): + raise NotImplementedError + + def reset(self, *args, **kwargs): + """ + Used to reset the aggregator state. + """ + self.scores = torch.zeros_like(self.scores) + + def load_state_dict(self, state_dict: dict, *args, **kwargs): + """ + Used to load the aggregator state. + """ + self.scores = state_dict["scores"] + + def state_dict(self, *args, **kwargs): + """ + Used to return the metric state. + """ + return {"scores": self.scores} + + def compute(self) -> torch.Tensor: + return self.scores.argsort() + + +class SumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: + self.scores += explanations.sum(dim=0) + + +class AbsSumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: + self.scores += explanations.abs().sum(dim=0) diff --git a/src/explainers/captum/base.py b/src/explainers/captum/base.py new file mode 100644 index 00000000..6211ffa7 --- /dev/null +++ b/src/explainers/captum/base.py @@ -0,0 +1,35 @@ +from typing import List, Union + +import torch +from captum.influence import DataInfluence + +from src.explainers.base import Explainer + + +class CaptumExplainerWrapper(Explainer): + def __init__( + self, + model: torch.nn.Module, + model_id: str, + train_dataset: torch.data.utils.Dataset, + device: Union[str, torch.device], + explainer_cls: DataInfluence, + **explainer_kwargs, + ): + super().__init__(model=model, model_id=model_id, train_dataset=train_dataset, device=device) + self.captum_explainer = explainer_cls(model=model, train_dataset=train_dataset, **explainer_kwargs) + + def explain( + self, test: torch.Tensor, targets: Union[torch.Tensor, List[int], None], **explainer_kwargs + ) -> torch.Tensor: + if targets is not None: + if not isinstance(targets, torch.Tensor): + if isinstance(targets, list): + targets = torch.tensor(targets) + else: + raise TypeError( + f"targets should be of type NoneType, List or torch.Tensor. Got {type(targets)} instead." + ) + return self.captum_explainer.influence(inputs=(test, targets), **explainer_kwargs) + else: + return self.captum_explainer.influence(inputs=test, **explainer_kwargs) diff --git a/src/explainers/captum/similarity.py b/src/explainers/captum/similarity.py new file mode 100644 index 00000000..19afd8e8 --- /dev/null +++ b/src/explainers/captum/similarity.py @@ -0,0 +1,28 @@ +from typing import Union + +import torch +from captum.influence import SimilarityInfluence + +from src.explainers.captum.base import CaptumExplainerWrapper + + +class CaptumSimilarityExplainer(CaptumExplainerWrapper): + def __init__( + self, + model: torch.nn.Module, + model_id: str, + train_dataset: torch.data.utils.Dataset, + device: Union[str, torch.device], + **explainer_kwargs, + ): + super().__init__( + model=model, + model_id=model_id, + train_dataset=train_dataset, + device=device, + explainer_cls=SimilarityInfluence, + **explainer_kwargs, + ) + + def explain(self, test: torch.Tensor) -> torch.Tensor: + return super().explain(test=test, targets=None) diff --git a/tests/explainers/test_aggregators.py b/tests/explainers/test_aggregators.py new file mode 100644 index 00000000..f800c5c5 --- /dev/null +++ b/tests/explainers/test_aggregators.py @@ -0,0 +1,44 @@ +import pytest +import torch + +from src.explainers.aggregators import AbsSumAggregator, SumAggregator + + +@pytest.mark.aggregators +@pytest.mark.parametrize( + "test_id, dataset, explanations", + [ + ( + "mnist", + "load_mnist_dataset", + "load_mnist_explanations_1", + ), + ], +) +def test_sum_aggregator(test_id, dataset, explanations, request): + dataset = request.getfixturevalue(dataset) + explanations = request.getfixturevalue(explanations) + aggregator = SumAggregator(training_size=len(dataset)) + aggregator.update(explanations) + global_rank = aggregator.compute() + assert torch.allclose(global_rank, explanations.sum(dim=0).argsort()) + + +@pytest.mark.aggregators +@pytest.mark.parametrize( + "test_id, dataset, explanations", + [ + ( + "mnist", + "load_mnist_dataset", + "load_mnist_explanations_1", + ), + ], +) +def test_abs_aggregator(test_id, dataset, explanations, request): + dataset = request.getfixturevalue(dataset) + explanations = request.getfixturevalue(explanations) + aggregator = AbsSumAggregator(training_size=len(dataset)) + aggregator.update(explanations) + global_rank = aggregator.compute() + assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort()) From 367057acdccbd4284ee9a268517d1f4b8b9c139f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Mon, 17 Jun 2024 17:22:14 +0200 Subject: [PATCH 05/30] delete leftover assertion --- src/utils/functions/similarities.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/utils/functions/similarities.py b/src/utils/functions/similarities.py index 699f705a..de781d23 100644 --- a/src/utils/functions/similarities.py +++ b/src/utils/functions/similarities.py @@ -39,7 +39,6 @@ def dot_product_similarity(test, train, replace_nan=0) -> Tensor: # TODO: I don't know why Captum return test activations as a list if isinstance(test, list): test = torch.cat(test) - assert torch.all(test == train) test = test.view(test.shape[0], -1) train = train.view(train.shape[0], -1) From e9e4afe0743ed971ea80d547ecffa39ea7cad7bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Mon, 17 Jun 2024 17:44:14 +0200 Subject: [PATCH 06/30] move self_influence files --- src/explainers/base.py | 11 ++++---- src/explainers/captum/base.py | 25 +++++++++++++------ src/explainers/captum/similarity.py | 6 +++-- .../{aggregators => }/self_influence.py | 0 .../randomization/model_randomization.py | 24 ++++++++++++++---- .../{aggregators => }/test_self_influence.py | 4 +-- tests/metrics/test_randomization_metrics.py | 14 +++++++---- 7 files changed, 56 insertions(+), 28 deletions(-) rename src/explainers/{aggregators => }/self_influence.py (100%) rename tests/explainers/{aggregators => }/test_self_influence.py (92%) diff --git a/src/explainers/base.py b/src/explainers/base.py index d29123e0..627afc7d 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -9,21 +9,20 @@ def __init__( self, model: torch.nn.Module, model_id: str, + cache_dir: Optional[str], train_dataset: torch.data.utils.Dataset, device: Union[str, torch.device], **kwargs, ): - self.model = model self.device = torch.device(device) if isinstance(device, str) else device self.train_dataset = train_dataset - self.samples = [] - self.labels = [] self._self_influences = None - dev = torch.device(device) - self.model.to(dev) + self.model.to(self.device) + self.model_id = model_id + self.cache_dir = cache_dir @abstractmethod - def explain(self, test: torch.Tensor, targets: Union[List[int], torch.Tensor, None], **kwargs): + def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]], **kwargs): raise NotImplementedError @abstractmethod diff --git a/src/explainers/captum/base.py b/src/explainers/captum/base.py index 6211ffa7..9b9af76c 100644 --- a/src/explainers/captum/base.py +++ b/src/explainers/captum/base.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Union import torch from captum.influence import DataInfluence @@ -11,17 +11,27 @@ def __init__( self, model: torch.nn.Module, model_id: str, + cache_dir: Optional[str], train_dataset: torch.data.utils.Dataset, device: Union[str, torch.device], explainer_cls: DataInfluence, - **explainer_kwargs, + **explainer_init_kwargs, ): - super().__init__(model=model, model_id=model_id, train_dataset=train_dataset, device=device) - self.captum_explainer = explainer_cls(model=model, train_dataset=train_dataset, **explainer_kwargs) + super().__init__( + model=model, model_id=model_id, train_dataset=train_dataset, device=device, cache_dir=cache_dir + ) + for shared_field_name in ["model_id", "cache_dir"]: + assert shared_field_name not in explainer_init_kwargs.keys(), ( + f"{shared_field_name} is already given to the explainer object, " + "it must not be repeated in the explainer_init_kwargs" + ) + + self.captum_explainer = explainer_cls(model=model, train_dataset=train_dataset, **explainer_init_kwargs) def explain( - self, test: torch.Tensor, targets: Union[torch.Tensor, List[int], None], **explainer_kwargs + self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]], **explain_fn_kwargs ) -> torch.Tensor: + test = test.to(self.device) if targets is not None: if not isinstance(targets, torch.Tensor): if isinstance(targets, list): @@ -30,6 +40,7 @@ def explain( raise TypeError( f"targets should be of type NoneType, List or torch.Tensor. Got {type(targets)} instead." ) - return self.captum_explainer.influence(inputs=(test, targets), **explainer_kwargs) + targets = targets.to(self.device) + return self.captum_explainer.influence(inputs=(test, targets), **explain_fn_kwargs) else: - return self.captum_explainer.influence(inputs=test, **explainer_kwargs) + return self.captum_explainer.influence(inputs=test, **explain_fn_kwargs) diff --git a/src/explainers/captum/similarity.py b/src/explainers/captum/similarity.py index 19afd8e8..8a9be2dd 100644 --- a/src/explainers/captum/similarity.py +++ b/src/explainers/captum/similarity.py @@ -11,17 +11,19 @@ def __init__( self, model: torch.nn.Module, model_id: str, + cache_dir: str, train_dataset: torch.data.utils.Dataset, device: Union[str, torch.device], - **explainer_kwargs, + **explainer_init_kwargs, ): super().__init__( model=model, model_id=model_id, + cache_dir=cache_dir, train_dataset=train_dataset, device=device, explainer_cls=SimilarityInfluence, - **explainer_kwargs, + **explainer_init_kwargs, ) def explain(self, test: torch.Tensor) -> torch.Tensor: diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/self_influence.py similarity index 100% rename from src/explainers/aggregators/self_influence.py rename to src/explainers/self_influence.py diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index 51e3f5e1..d07608c6 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -3,10 +3,11 @@ import torch -from metrics.base import Metric -from utils.common import _get_parent_module_from_name, make_func -from utils.explain_wrapper import ExplainFunc -from utils.functions.correlations import ( +from src.explainers.base import Explainer +from src.explainers.functional import ExplainFunc +from src.metrics.base import Metric +from src.utils.common import _get_parent_module_from_name, make_func +from src.utils.functions.correlations import ( CorrelationFnLiterals, correlation_functions, ) @@ -21,7 +22,8 @@ def __init__( self, model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, - explain_fn: ExplainFunc, + explainer: Union[ExplainFunc, Explainer], + explainer_init_kwargs: Optional[dict] = None, explain_fn_kwargs: Optional[dict] = None, correlation_fn: Union[Callable, CorrelationFnLiterals] = "spearman", seed: int = 42, @@ -52,6 +54,16 @@ def __init__( self.generator = torch.Generator(device=device) self.generator.manual_seed(self.seed) self.rand_model = self._randomize_model(model) + + explain_fn = explainer + self.explainer = None + if isinstance(explainer, Explainer): + self.explainer = explainer + explain_fn = explainer.explain + elif not callable(explainer): + raise TypeError( + f"Parameter 'explainer' should be of type Explainer of Callable. Got {type(explainer)} instead." + ) self.explain_fn = make_func( func=explain_fn, func_kwargs=explain_fn_kwargs, @@ -95,6 +107,7 @@ def state_dict(self): "random_model_state_dict": self.model.state_dict(), "seed": self.seed, "generator_state": self.generator.get_state(), + "explainer": self.explainer, "explain_fn": self.explain_fn, } return state_dict @@ -102,6 +115,7 @@ def state_dict(self): def load_state_dict(self, state_dict: dict): self.results = state_dict["results_dict"] self.seed = state_dict["seed"] + self.explainer = state_dict["explainer"] self.explain_fn = state_dict["explain_fn"] self.rand_model.load_state_dict(state_dict["random_model_state_dict"]) self.generator.set_state(state_dict["generator_state"]) diff --git a/tests/explainers/aggregators/test_self_influence.py b/tests/explainers/test_self_influence.py similarity index 92% rename from tests/explainers/aggregators/test_self_influence.py rename to tests/explainers/test_self_influence.py index b43b9dae..6d7d9c73 100644 --- a/tests/explainers/aggregators/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -4,9 +4,7 @@ import torch from torch.utils.data import TensorDataset -from src.explainers.aggregators.self_influence import ( - get_self_influence_ranking, -) +from src.explainers.self_influence import get_self_influence_ranking from src.utils.explain_wrapper import explain from src.utils.functions.similarities import dot_product_similarity diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index c33cc503..636de53a 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -1,8 +1,10 @@ import pytest import torch -from metrics.randomization.model_randomization import ModelRandomizationMetric -from utils.explain_wrapper import explain +from src.metrics.randomization.model_randomization import ( + ModelRandomizationMetric, +) +from src.utils.explain_wrapper import explain @pytest.mark.randomization_metrics @@ -15,13 +17,13 @@ "load_mnist_dataset", "load_mnist_test_samples_1", 8, - {"method": "SimilarityInfluence", "layer": "fc_2"}, + {"layer": "fc_2"}, "load_mnist_explanations_1", "spearman", ), ], ) -def test_randomization_metric( +def test_randomization_metric_functional( test_id, model, dataset, test_data, batch_size, explain_kwargs, explanations, corr_measure, request ): model = request.getfixturevalue(model) @@ -32,11 +34,13 @@ def test_randomization_metric( model=model, train_dataset=dataset, explain_fn=explain, - explain_fn_kwargs={**explain_kwargs, "layer": "fc_2"}, + explain_fn_kwargs=explain_kwargs, correlation_fn="spearman", seed=42, device="cpu", ) + # TODO: introduce a more meaningful test + # Can we come up with a special attributor that gets exactly 0 score? metric.update(test_data, tda) out = metric.compute() assert (out.item() >= -1.0) and (out.item() <= 1.0), "Test failed." From 5d53cd53d48ed982e18de6fec898237b77098c9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Mon, 17 Jun 2024 19:47:50 +0200 Subject: [PATCH 07/30] randomization test works --- src/explainers/base.py | 6 ++-- src/explainers/captum/base.py | 7 +++-- src/explainers/captum/similarity.py | 24 ++++++++++++++-- .../randomization/model_randomization.py | 28 ++++++++----------- src/utils/common.py | 2 +- tests/metrics/test_randomization_metrics.py | 17 +++++++---- 6 files changed, 54 insertions(+), 30 deletions(-) diff --git a/src/explainers/base.py b/src/explainers/base.py index 627afc7d..804fa477 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -10,13 +10,14 @@ def __init__( model: torch.nn.Module, model_id: str, cache_dir: Optional[str], - train_dataset: torch.data.utils.Dataset, + train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], **kwargs, ): self.device = torch.device(device) if isinstance(device, str) else device self.train_dataset = train_dataset self._self_influences = None + self.model = model self.model.to(self.device) self.model_id = model_id self.cache_dir = cache_dir @@ -37,12 +38,13 @@ def reset(self): raise NotImplementedError def self_influences(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: + # Base class implements computing self influences by explaining the train dataset one by one if self._self_influences is None: self._self_influences = torch.empty((len(self.train_dataset),), device=self.device) ldr = torch.nn.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) for i, (x, y) in iter(ldr): upper_index = i * batch_size + x.shape[0] - explanations = self.explain(test=x, **kwargs) + explanations = self.explain(test=x.to(self.device), **kwargs) explanations = explanations[:, i:upper_index] self._self_influences[i:upper_index] = torch.diag(explanations) return self._self_influences diff --git a/src/explainers/captum/base.py b/src/explainers/captum/base.py index 9b9af76c..6cc07ab0 100644 --- a/src/explainers/captum/base.py +++ b/src/explainers/captum/base.py @@ -12,7 +12,7 @@ def __init__( model: torch.nn.Module, model_id: str, cache_dir: Optional[str], - train_dataset: torch.data.utils.Dataset, + train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], explainer_cls: DataInfluence, **explainer_init_kwargs, @@ -26,7 +26,10 @@ def __init__( "it must not be repeated in the explainer_init_kwargs" ) - self.captum_explainer = explainer_cls(model=model, train_dataset=train_dataset, **explainer_init_kwargs) + self.initialize_captum(explainer_cls, **explainer_init_kwargs) + + def initialize_captum(self, cls, **init_kwargs): + self.captum_explainer = cls(model=self.model, train_dataset=self.train_dataset, **init_kwargs) def explain( self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]], **explain_fn_kwargs diff --git a/src/explainers/captum/similarity.py b/src/explainers/captum/similarity.py index 8a9be2dd..05178d54 100644 --- a/src/explainers/captum/similarity.py +++ b/src/explainers/captum/similarity.py @@ -12,10 +12,12 @@ def __init__( model: torch.nn.Module, model_id: str, cache_dir: str, - train_dataset: torch.data.utils.Dataset, + train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], + layer: str, **explainer_init_kwargs, ): + self.layer = layer super().__init__( model=model, model_id=model_id, @@ -23,8 +25,26 @@ def __init__( train_dataset=train_dataset, device=device, explainer_cls=SimilarityInfluence, + layers=[layer], **explainer_init_kwargs, ) + def initialize_captum(self, cls, **init_kwargs): + self.captum_explainer = cls( + module=self.model, + influence_src_dataset=self.train_dataset, + activation_dir=self.cache_dir, + model_id=self.model_id, + **init_kwargs, + ) + + def load_state_dict(self, path): + return + + def reset(self): + return + def explain(self, test: torch.Tensor) -> torch.Tensor: - return super().explain(test=test, targets=None) + topk_idx, topk_val = super().explain(test=test, targets=None, top_k=len(self.train_dataset))[self.layer] + tda = torch.gather(topk_val, 1, topk_idx) + return tda diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index d07608c6..90d5a669 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -3,7 +3,6 @@ import torch -from src.explainers.base import Explainer from src.explainers.functional import ExplainFunc from src.metrics.base import Metric from src.utils.common import _get_parent_module_from_name, make_func @@ -22,9 +21,9 @@ def __init__( self, model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, - explainer: Union[ExplainFunc, Explainer], - explainer_init_kwargs: Optional[dict] = None, - explain_fn_kwargs: Optional[dict] = None, + explain_fn: ExplainFunc, + explain_init_kwargs: Optional[dict] = {}, + explain_fn_kwargs: Optional[dict] = {}, correlation_fn: Union[Callable, CorrelationFnLiterals] = "spearman", seed: int = 42, model_id: str = "0", @@ -41,6 +40,7 @@ def __init__( self.model = model self.train_dataset = train_dataset self.explain_fn_kwargs = explain_fn_kwargs + self.explain_init_kwargs = explain_init_kwargs self.seed = seed self.model_id = model_id self.cache_dir = cache_dir @@ -55,18 +55,10 @@ def __init__( self.generator.manual_seed(self.seed) self.rand_model = self._randomize_model(model) - explain_fn = explainer - self.explainer = None - if isinstance(explainer, Explainer): - self.explainer = explainer - explain_fn = explainer.explain - elif not callable(explainer): - raise TypeError( - f"Parameter 'explainer' should be of type Explainer of Callable. Got {type(explainer)} instead." - ) self.explain_fn = make_func( func=explain_fn, - func_kwargs=explain_fn_kwargs, + init_kwargs=explain_init_kwargs, + explain_kwargs=explain_fn_kwargs, model_id=self.model_id, cache_dir=self.cache_dir, train_dataset=self.train_dataset, @@ -88,8 +80,12 @@ def update( self, test_data: torch.Tensor, explanations: torch.Tensor, + explanation_targets: torch.Tensor, ): - rand_explanations = self.explain_fn(model=self.rand_model, test_tensor=test_data) + device = "cuda" if torch.cuda.is_available() else "cpu" + rand_explanations = self.explain_fn( + model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=device + ) corrs = self.correlation_measure(explanations, rand_explanations) self.results["rank_correlations"].append(corrs) @@ -107,7 +103,6 @@ def state_dict(self): "random_model_state_dict": self.model.state_dict(), "seed": self.seed, "generator_state": self.generator.get_state(), - "explainer": self.explainer, "explain_fn": self.explain_fn, } return state_dict @@ -115,7 +110,6 @@ def state_dict(self): def load_state_dict(self, state_dict: dict): self.results = state_dict["results_dict"] self.seed = state_dict["seed"] - self.explainer = state_dict["explainer"] self.explain_fn = state_dict["explain_fn"] self.rand_model.load_state_dict(state_dict["random_model_state_dict"]) self.generator.set_state(state_dict["generator_state"]) diff --git a/src/utils/common.py b/src/utils/common.py index a1dcae3c..8de51d1d 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -15,7 +15,7 @@ def _get_parent_module_from_name(model: torch.nn.Module, layer_name: str) -> Any return reduce(getattr, layer_name.split(".")[:-1], model) -def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None, **kwargs) -> functools.partial: +def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None = None, **kwargs) -> functools.partial: """A function for creating a partial function with the given arguments.""" if func_kwargs is not None: _func_kwargs = kwargs.copy() diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 636de53a..2121c229 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -1,10 +1,10 @@ import pytest import torch +from src.explainers.functional import captum_similarity_explain from src.metrics.randomization.model_randomization import ( ModelRandomizationMetric, ) -from src.utils.explain_wrapper import explain @pytest.mark.randomization_metrics @@ -17,31 +17,36 @@ "load_mnist_dataset", "load_mnist_test_samples_1", 8, - {"layer": "fc_2"}, + { + "layer": "fc_2", + }, "load_mnist_explanations_1", + "load_mnist_test_labels_1", "spearman", ), ], ) def test_randomization_metric_functional( - test_id, model, dataset, test_data, batch_size, explain_kwargs, explanations, corr_measure, request + test_id, model, dataset, test_data, batch_size, explain_init_kwargs, explanations, test_labels, request ): model = request.getfixturevalue(model) test_data = request.getfixturevalue(test_data) dataset = request.getfixturevalue(dataset) + explain_init_kwargs = request.getfixturevalue(explain_init_kwargs) + test_labels = request.getfixturevalue(test_labels) tda = request.getfixturevalue(explanations) metric = ModelRandomizationMetric( model=model, train_dataset=dataset, - explain_fn=explain, - explain_fn_kwargs=explain_kwargs, + explain_fn=captum_similarity_explain, + explain_init_kwargs=explain_init_kwargs, correlation_fn="spearman", seed=42, device="cpu", ) # TODO: introduce a more meaningful test # Can we come up with a special attributor that gets exactly 0 score? - metric.update(test_data, tda) + metric.update(test_data=test_data, explanations=tda, explanation_targets=test_labels) out = metric.compute() assert (out.item() >= -1.0) and (out.item() <= 1.0), "Test failed." assert isinstance(out, torch.Tensor), "Output is not a tensor." From b63ffdd7474c966855fe6249fce1440f2cfa6c6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 02:11:20 +0200 Subject: [PATCH 08/30] fix similarity explainer output construciton logic --- src/explainers/base.py | 12 ++--- src/explainers/captum/similarity.py | 4 +- src/utils/explain_wrapper.py | 68 ----------------------------- 3 files changed, 9 insertions(+), 75 deletions(-) delete mode 100644 src/utils/explain_wrapper.py diff --git a/src/explainers/base.py b/src/explainers/base.py index 804fa477..510b8d0d 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -37,14 +37,14 @@ def state_dict(self): def reset(self): raise NotImplementedError - def self_influences(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: + def self_influence_ranking(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: # Base class implements computing self influences by explaining the train dataset one by one if self._self_influences is None: self._self_influences = torch.empty((len(self.train_dataset),), device=self.device) - ldr = torch.nn.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) - for i, (x, y) in iter(ldr): + ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) + for i, (x, y) in enumerate(iter(ldr)): upper_index = i * batch_size + x.shape[0] explanations = self.explain(test=x.to(self.device), **kwargs) - explanations = explanations[:, i:upper_index] - self._self_influences[i:upper_index] = torch.diag(explanations) - return self._self_influences + explanations = explanations[:, i * batch_size : upper_index] + self._self_influences[i * batch_size : upper_index] = explanations.diag() + return self._self_influences.argsort() diff --git a/src/explainers/captum/similarity.py b/src/explainers/captum/similarity.py index 05178d54..0abc5358 100644 --- a/src/explainers/captum/similarity.py +++ b/src/explainers/captum/similarity.py @@ -35,6 +35,7 @@ def initialize_captum(self, cls, **init_kwargs): influence_src_dataset=self.train_dataset, activation_dir=self.cache_dir, model_id=self.model_id, + similarity_direction="max", **init_kwargs, ) @@ -46,5 +47,6 @@ def reset(self): def explain(self, test: torch.Tensor) -> torch.Tensor: topk_idx, topk_val = super().explain(test=test, targets=None, top_k=len(self.train_dataset))[self.layer] - tda = torch.gather(topk_val, 1, topk_idx) + inverted_idx = topk_idx.argsort() + tda = torch.cat([topk_val[None, i, inverted_idx[i]] for i in range(topk_idx.shape[0])], dim=0) return tda diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py deleted file mode 100644 index bb82460d..00000000 --- a/src/utils/explain_wrapper.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import List, Optional, Protocol, Union - -import torch -from captum.influence import SimilarityInfluence - -from src.utils.datasets.indexed_subset import IndexedSubset -from src.utils.functions.similarities import cosine_similarity - - -class ExplainFunc(Protocol): - def __call__( - self, - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - method: str, - test_tensor: torch.Tensor, - train_dataset: torch.utils.data.Dataset, - train_ids: Optional[Union[List[int], torch.Tensor]] = None, - ) -> torch.Tensor: - pass - - -def explain( - model: torch.nn.Module, - model_id: str, - cache_dir: str, - method: str, - train_dataset: torch.utils.data.Dataset, - test_tensor: torch.Tensor, - test_target: Optional[torch.Tensor] = None, - train_ids: Optional[Union[List[int], torch.Tensor]] = None, - **kwargs, -) -> torch.Tensor: - """ - Return influential examples for test_tensor retrieved from train_dataset for each test example represented through - a tensor. - :param model: - :param model_id: - :param cache_dir: - :param train_dataset: - :param test_tensor: - :param method: - :param kwargs: - :return: - """ - if method == "SimilarityInfluence": - if train_ids is not None: - train_dataset = IndexedSubset(dataset=train_dataset, indices=train_ids) - layer = kwargs.get("layer", "features") - sim_metric = kwargs.get("similarity_metric", cosine_similarity) - sim_direction = kwargs.get("similarity_direction", "max") - batch_size = kwargs.get("batch_size", 1) - - sim_influence = SimilarityInfluence( - module=model, - layers=layer, - influence_src_dataset=train_dataset, - activation_dir=cache_dir, - model_id=model_id, - similarity_metric=sim_metric, - similarity_direction=sim_direction, - batch_size=batch_size, - ) - topk_idx, topk_val = sim_influence.influence(test_tensor, len(train_dataset))[layer] - tda = torch.gather(topk_val, 1, topk_idx) - - return tda From 09c73e747c05c11a74fddbd2d2163e0b73bbd088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 02:12:19 +0200 Subject: [PATCH 09/30] update with correct cos distance explanations --- .../mnist_SimilarityInfluence_tda.pt | Bin 1674 -> 1674 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt b/tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt index 50e65d6e11dd2899c92f6c8c07f035ed095a420d..104654431a67c0ce1a1dbe6f4c0e077312309603 100644 GIT binary patch delta 421 zcmV;W0b2fw4T=q*0)P65z&>-qB|j?zuRf+avOYEOG13%myB0eCe zC_nU;A3wen6F>4l#y&Pn+CDaQ@jkGaJw64`13xqGqdtV~&^}lc0Y3`6)II=CB|ok| z5uHX1*@ z$`C*H6aqgVvE@E0;txN~@Toq|zT`fr=m=ZxwvdupA+x9*ORt!H;pQS!x_1Qjb^rJqg%91{D!2&;?3+6uD zLghZttL{FNhfuseh7{We!nF@ z)#}AQSaY&IPzALwWx_ PH3gvr>Hr}6liCHaKO3&a delta 421 zcmV;W0b2fw4T=q*0)MAEvOaRbB|kGFnm#H6uRd8?EkE@_);>LF13&tRz&`YrA3who z6F)X}@jj)PJwEV1#y&Mm+CJMIB0e3bC_gUmqdpS5)INso&^`>5C_e_z13y?40Y9!k z5u$`C&!vE@FlN@703HX1*Khp9fpS<*g&E+jvOlPEuYnX5jl=m=ZxRLghYe^rJp$_1Qj@hrB*-!2&<3 z%91|QtL{Fb3sB}hyj2T7#H3n27$D+4>q$&LJXxwf=#*7HNbrb0)_7MxPX)C;fWIX_ zTywHMSUd1O>F%;Vu$R6*h`l*K*XqSSP)i30sI?+KKmh;%Kmn7`100iT1J^GwG&eCa zGdDLeIW{*jIX5*pGBz+UFfcJVF*Y|jI5spmG&wOeH!@I52MFOh1A| From c5ddc0491ca953f36050c1cb349dfe3b25df7008 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 02:13:23 +0200 Subject: [PATCH 10/30] fix and add tests --- tests/explainers/test_self_influence.py | 41 +++++++++++++++++++------ tests/metrics/test_unnamed_metrics.py | 2 +- tests/utils/test_explain_wrapper.py | 38 ----------------------- 3 files changed, 32 insertions(+), 49 deletions(-) delete mode 100644 tests/utils/test_explain_wrapper.py diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index 6d7d9c73..918b0c10 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -4,8 +4,8 @@ import torch from torch.utils.data import TensorDataset -from src.explainers.self_influence import get_self_influence_ranking -from src.utils.explain_wrapper import explain +from src.explainers.captum.similarity import CaptumSimilarityExplainer +from src.explainers.functional import captum_similarity_self_influence_ranking from src.utils.functions.similarities import dot_product_similarity @@ -15,22 +15,43 @@ [ ( "random_data", - {"method": "SimilarityInfluence", "layer": "identity", "similarity_metric": dot_product_similarity}, + {"layer": "identity", "similarity_metric": dot_product_similarity}, ), ], ) -def test_self_influence_ranking(test_id, explain_kwargs, request): +def test_self_influence(test_id, init_kwargs, request): model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())])) + # X=torch.randn(1,200) + import os + import shutil + + os.mkdir("temp_captum") + os.mkdir("temp_captum2") + torch.random.manual_seed(42) X = torch.randn(100, 200) - rand_dataset = TensorDataset(X, torch.randint(0, 10, (100,))) + # rand_dataset = TensorDataset(X,torch.randint(0,10,(1,))) + y = torch.randint(0, 10, (100,)) + rand_dataset = TensorDataset(X, y) + init_kwargs = {"layer": "identity", "similarity_metric": dot_product_similarity} - self_influence_rank = get_self_influence_ranking( + self_influence_rank_functional = captum_similarity_self_influence_ranking( model=model, model_id="0", cache_dir="temp_captum", - training_data=rand_dataset, - explain_fn=explain, - explain_fn_kwargs=explain_kwargs, + train_dataset=rand_dataset, + init_kwargs=init_kwargs, + device="cpu", + ) + + explainer_obj = CaptumSimilarityExplainer( + model=model, model_id="1", cache_dir="temp_captum2", train_dataset=rand_dataset, device="cpu", **init_kwargs ) + self_influence_rank_stateful = explainer_obj.self_influence_ranking() + + if os.path.isdir("temp_captum2"): + shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2")) + if os.path.isdir("temp_captum"): + shutil.rmtree(os.path.join(os.getcwd(), "temp_captum")) - assert torch.allclose(self_influence_rank, torch.linalg.norm(X, dim=-1).argsort()) + assert torch.allclose(self_influence_rank_functional, torch.linalg.norm(X, dim=-1).argsort()) + assert torch.allclose(self_influence_rank_functional, self_influence_rank_stateful) diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index 5b9c6e4c..492bf3d5 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( "test_id, model, dataset, top_k, batch_size, explanations, expected_score", [ - ("mnist", "load_mnist_model", "load_mnist_dataset", 3, 8, "load_mnist_explanations_1", 8), + ("mnist", "load_mnist_model", "load_mnist_dataset", 3, 8, "load_mnist_explanations_1", 7), ], ) def test_top_k_overlap_metrics(test_id, model, dataset, top_k, batch_size, explanations, expected_score, request): diff --git a/tests/utils/test_explain_wrapper.py b/tests/utils/test_explain_wrapper.py deleted file mode 100644 index 1e03136c..00000000 --- a/tests/utils/test_explain_wrapper.py +++ /dev/null @@ -1,38 +0,0 @@ -import os - -import pytest -import torch - -from src.utils.explain_wrapper import explain - - -@pytest.mark.explainers -@pytest.mark.parametrize( - "test_id, model, dataset, test_tensor, method, method_kwargs, explanations", - [ - ( - "mnist", - "load_mnist_model", - "load_mnist_dataset", - "load_mnist_test_samples_1", - "SimilarityInfluence", - {"layer": "relu_4"}, - "load_mnist_explanations_1", - ), - ], -) -def test_explain(test_id, model, dataset, explanations, test_tensor, method, method_kwargs, request): - model = request.getfixturevalue(model) - dataset = request.getfixturevalue(dataset) - test_tensor = request.getfixturevalue(test_tensor) - explanations_exp = request.getfixturevalue(explanations) - explanations = explain( - model, - test_id, - os.path.join("./cache", "test_id"), - method, - dataset, - test_tensor, - **method_kwargs, - ) - assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" From 779c0ff8e235d7dbb3e7d95f6ef4018685203276 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 12:29:14 +0200 Subject: [PATCH 11/30] Add missing files --- src/explainers/functional.py | 109 ++++++++++++++++++++++++++++ tests/explainers/test_explainers.py | 77 ++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 src/explainers/functional.py create mode 100644 tests/explainers/test_explainers.py diff --git a/src/explainers/functional.py b/src/explainers/functional.py new file mode 100644 index 00000000..6f1b3a9a --- /dev/null +++ b/src/explainers/functional.py @@ -0,0 +1,109 @@ +from typing import Dict, List, Optional, Protocol, Union + +import torch + +from src.explainers.base import Explainer +from src.explainers.captum.similarity import CaptumSimilarityExplainer + + +class ExplainFunc(Protocol): + def __call__( + self, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + test_tensor: torch.Tensor, + explanation_targets: Optional[Union[List[int], torch.Tensor]], + train_dataset: torch.utils.data.Dataset, + explain_kwargs: Dict, + init_kwargs: Dict, + device: Union[str, torch.device], + ) -> torch.Tensor: + pass + + +def explainer_functional_interface( + explainer_cls: Explainer, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + test_tensor: torch.Tensor, + explanation_targets: Optional[Union[List[int], torch.Tensor]], + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + init_kwargs: Optional[Dict] = {}, + explain_kwargs: Optional[Dict] = {}, +) -> torch.Tensor: + explainer = explainer_cls( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + **init_kwargs, + ) + return explainer.explain(test=test_tensor, **explain_kwargs) + + +def explainer_self_influence_interface( + explainer_cls: Explainer, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + init_kwargs: Dict, + device: Union[str, torch.device], +) -> torch.Tensor: + explainer = explainer_cls( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + **init_kwargs, + ) + return explainer.self_influence_ranking() + + +def captum_similarity_explain( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + test_tensor: torch.Tensor, + explanation_targets: Optional[Union[List[int], torch.Tensor]], + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + init_kwargs: Optional[Dict] = {}, + explain_kwargs: Optional[Dict] = {}, +) -> torch.Tensor: + return explainer_functional_interface( + explainer_cls=CaptumSimilarityExplainer, + model=model, + model_id=model_id, + cache_dir=cache_dir, + test_tensor=test_tensor, + explanation_targets=explanation_targets, + train_dataset=train_dataset, + device=device, + init_kwargs=init_kwargs, + explain_kwargs=explain_kwargs, + ) + + +def captum_similarity_self_influence_ranking( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + init_kwargs: Dict, + device: Union[str, torch.device], +) -> torch.Tensor: + return explainer_self_influence_interface( + explainer_cls=CaptumSimilarityExplainer, + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + init_kwargs=init_kwargs, + ) diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py new file mode 100644 index 00000000..2bc49b46 --- /dev/null +++ b/tests/explainers/test_explainers.py @@ -0,0 +1,77 @@ +import os + +import pytest +import torch + +from src.explainers.captum.similarity import CaptumSimilarityExplainer +from src.explainers.functional import captum_similarity_explain +from src.utils.functions.similarities import cosine_similarity + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, test_tensor, method, method_kwargs, explanations", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + "SimilarityInfluence", + {"layer": "relu_4"}, + "load_mnist_explanations_1", + ), + ], +) +def test_explain_functional(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + explanations = captum_similarity_explain( + model, + "test_id", + os.path.join("./cache", "test_id"), + test_tensor, + test_labels, + dataset, + device="cpu", + init_kwargs=method_kwargs, + ) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, test_tensor, method, method_kwargs, explanations", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + "SimilarityInfluence", + {"layer": "relu_4", "similarity_metric": cosine_similarity}, + "load_mnist_explanations_1", + ), + ], +) +def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + explainer = CaptumSimilarityExplainer( + model=model, + model_id="test_id", + cache_dir=os.path.join("./cache", "test_id"), + train_dataset=dataset, + device="cpu", + **method_kwargs, + ) + explanations = explainer.explain(test_tensor, test_labels) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" From 299b4d3b5fda93e696f8e8e7b282bf89db6650b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 12:38:32 +0200 Subject: [PATCH 12/30] fix test script --- tests/metrics/test_randomization_metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 2121c229..79b7d4c7 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -22,7 +22,6 @@ }, "load_mnist_explanations_1", "load_mnist_test_labels_1", - "spearman", ), ], ) From 434328208c76bd15381bfbc7628382fd6123e5b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 12:48:38 +0200 Subject: [PATCH 13/30] fix test headers --- tests/explainers/test_explainers.py | 5 ++--- tests/explainers/test_self_influence.py | 2 +- tests/metrics/test_randomization_metrics.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index 2bc49b46..dbbe6989 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -10,7 +10,7 @@ @pytest.mark.explainers @pytest.mark.parametrize( - "test_id, model, dataset, test_tensor, method, method_kwargs, explanations", + "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations", [ ( "mnist", @@ -18,13 +18,12 @@ "load_mnist_dataset", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - "SimilarityInfluence", {"layer": "relu_4"}, "load_mnist_explanations_1", ), ], ) -def test_explain_functional(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request): +def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) test_tensor = request.getfixturevalue(test_tensor) diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index 918b0c10..3faf9018 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -11,7 +11,7 @@ @pytest.mark.self_influence @pytest.mark.parametrize( - "test_id, explain_kwargs", + "test_id, init_kwargs", [ ( "random_data", diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 79b7d4c7..7d7edb20 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -9,7 +9,7 @@ @pytest.mark.randomization_metrics @pytest.mark.parametrize( - "test_id, model, dataset, test_data, batch_size, explain_kwargs, explanations, corr_measure", + "test_id, model, dataset, test_data, batch_size, explain_init_kwargs, explanations, test_labels", [ ( "mnist", From 0a09a1ab827b6fab1e5cf01dbcbda9160c80aeea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 12:50:48 +0200 Subject: [PATCH 14/30] delete unneeded file --- src/explainers/self_influence.py | 30 ------------------------------ 1 file changed, 30 deletions(-) delete mode 100644 src/explainers/self_influence.py diff --git a/src/explainers/self_influence.py b/src/explainers/self_influence.py deleted file mode 100644 index 96406e62..00000000 --- a/src/explainers/self_influence.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Optional - -import torch - -from utils.explain_wrapper import ExplainFunc - - -def get_self_influence_ranking( - model: torch.nn.Module, - model_id: str, - cache_dir: str, - training_data: torch.utils.data.Dataset, - explain_fn: ExplainFunc, - explain_fn_kwargs: Optional[dict] = None, -) -> torch.Tensor: - size = len(training_data) - self_inf = torch.zeros((size,)) - - for i, (x, y) in enumerate(training_data): - self_inf[i] = explain_fn( - model=model, - model_id=f"{model_id}_id_{i}", - cache_dir=cache_dir, - test_tensor=x[None], - test_label=y[None], - train_dataset=training_data, - train_ids=[i], - **explain_fn_kwargs, - ) - return self_inf.argsort() From 19a7b3fe229560d7fdf448f81c3e97b189f2ab12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 13:12:25 +0200 Subject: [PATCH 15/30] fix test --- tests/explainers/test_explainers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index dbbe6989..6d343b16 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -50,11 +50,10 @@ def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, m "mnist", "load_mnist_model", "load_mnist_dataset", + "load_mnist_explanations_1", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - "SimilarityInfluence", {"layer": "relu_4", "similarity_metric": cosine_similarity}, - "load_mnist_explanations_1", ), ], ) From ee8d8eb6334c481cad5d5c5554e67ea581df345c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 18 Jun 2024 13:48:07 +0200 Subject: [PATCH 16/30] fix test --- tests/explainers/test_explainers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index 6d343b16..fd8be80a 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -44,7 +44,7 @@ def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, m @pytest.mark.explainers @pytest.mark.parametrize( - "test_id, model, dataset, test_tensor, method, method_kwargs, explanations", + "test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs", [ ( "mnist", From 888aefd0eb826700883474666635919d183219f9 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Tue, 18 Jun 2024 21:06:17 +0200 Subject: [PATCH 17/30] fixing explainer tests --- src/explainers/aggregators.py | 27 ++++++++---- src/explainers/{base.py => base_explainer.py} | 35 ++++++++------- src/explainers/functional.py | 8 ++-- src/explainers/wrappers/__init__.py | 0 src/explainers/{captum => wrappers}/base.py | 6 +-- .../{captum => wrappers}/similarity.py | 5 ++- src/utils/common.py | 12 +++++ tests/explainers/test_base_explainer.py | 44 +++++++++++++++++++ tests/explainers/test_explainers.py | 6 +-- tests/explainers/test_self_influence.py | 2 +- 10 files changed, 109 insertions(+), 36 deletions(-) rename src/explainers/{base.py => base_explainer.py} (53%) create mode 100644 src/explainers/wrappers/__init__.py rename src/explainers/{captum => wrappers}/base.py (92%) rename src/explainers/{captum => wrappers}/similarity.py (93%) create mode 100644 tests/explainers/test_base_explainer.py diff --git a/src/explainers/aggregators.py b/src/explainers/aggregators.py index 0d36b6b4..4d690f67 100644 --- a/src/explainers/aggregators.py +++ b/src/explainers/aggregators.py @@ -3,19 +3,28 @@ import torch -class ExplanationsAggregator(ABC): - def __init__(self, training_size: int, *args, **kwargs): - self.scores = torch.zeros(training_size) +class BaseAggregator(ABC): + def __init__(self): + self.scores: torch.Tensor = None @abstractmethod def update(self, explanations: torch.Tensor): raise NotImplementedError + def _validate_explanations(self, explanations: torch.Tensor): + if self.scores is None: + self.scores = torch.zeros(explanations.shape[1]) + + if explanations.shape[1] != self.scores.shape[0]: + raise ValueError( + f"Explanations shape {explanations.shape} does not match the expected shape {self.scores.shape}" + ) + def reset(self, *args, **kwargs): """ Used to reset the aggregator state. """ - self.scores = torch.zeros_like(self.scores) + self.scores: torch.Tensor = None def load_state_dict(self, state_dict: dict, *args, **kwargs): """ @@ -33,11 +42,13 @@ def compute(self) -> torch.Tensor: return self.scores.argsort() -class SumAggregator(ExplanationsAggregator): - def update(self, explanations: torch.Tensor) -> torch.Tensor: +class SumAggregator(BaseAggregator): + def update(self, explanations: torch.Tensor): + self._validate_explanations(explanations) self.scores += explanations.sum(dim=0) -class AbsSumAggregator(ExplanationsAggregator): - def update(self, explanations: torch.Tensor) -> torch.Tensor: +class AbsSumAggregator(BaseAggregator): + def update(self, explanations: torch.Tensor): + self._validate_explanations(explanations) self.scores += explanations.abs().sum(dim=0) diff --git a/src/explainers/base.py b/src/explainers/base_explainer.py similarity index 53% rename from src/explainers/base.py rename to src/explainers/base_explainer.py index 510b8d0d..25c59624 100644 --- a/src/explainers/base.py +++ b/src/explainers/base_explainer.py @@ -3,8 +3,10 @@ import torch +from utils.common import cache_result -class Explainer(ABC): + +class BaseExplainer(ABC): def __init__( self, model: torch.nn.Module, @@ -14,22 +16,23 @@ def __init__( device: Union[str, torch.device], **kwargs, ): - self.device = torch.device(device) if isinstance(device, str) else device - self.train_dataset = train_dataset - self._self_influences = None self.model = model - self.model.to(self.device) + self.model.to(device) + self.model_id = model_id self.cache_dir = cache_dir + self.train_dataset = train_dataset + self.device = torch.device(device) if isinstance(device, str) else device @abstractmethod - def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]], **kwargs): + def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs): raise NotImplementedError @abstractmethod def load_state_dict(self, path): raise NotImplementedError + @abstractmethod def state_dict(self): raise NotImplementedError @@ -37,14 +40,14 @@ def state_dict(self): def reset(self): raise NotImplementedError - def self_influence_ranking(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: + @cache_result + def self_influence(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: # Base class implements computing self influences by explaining the train dataset one by one - if self._self_influences is None: - self._self_influences = torch.empty((len(self.train_dataset),), device=self.device) - ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) - for i, (x, y) in enumerate(iter(ldr)): - upper_index = i * batch_size + x.shape[0] - explanations = self.explain(test=x.to(self.device), **kwargs) - explanations = explanations[:, i * batch_size : upper_index] - self._self_influences[i * batch_size : upper_index] = explanations.diag() - return self._self_influences.argsort() + influences = torch.empty((len(self.train_dataset),), device=self.device) + ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) + for i, (x, y) in enumerate(iter(ldr)): + upper_index = i * batch_size + x.shape[0] + explanations = self.explain(test=x.to(self.device), **kwargs) + explanations = explanations[:, i * batch_size : upper_index] + influences[i * batch_size : upper_index] = explanations.diag() + return influences.argsort() diff --git a/src/explainers/functional.py b/src/explainers/functional.py index 6f1b3a9a..5888c891 100644 --- a/src/explainers/functional.py +++ b/src/explainers/functional.py @@ -2,8 +2,8 @@ import torch -from src.explainers.base import Explainer -from src.explainers.captum.similarity import CaptumSimilarityExplainer +from src.explainers.base_explainer import BaseExplainer +from src.explainers.wrappers.similarity import CaptumSimilarityExplainer class ExplainFunc(Protocol): @@ -23,7 +23,7 @@ def __call__( def explainer_functional_interface( - explainer_cls: Explainer, + explainer_cls: BaseExplainer, model: torch.nn.Module, model_id: str, cache_dir: Optional[str], @@ -46,7 +46,7 @@ def explainer_functional_interface( def explainer_self_influence_interface( - explainer_cls: Explainer, + explainer_cls: BaseExplainer, model: torch.nn.Module, model_id: str, cache_dir: Optional[str], diff --git a/src/explainers/wrappers/__init__.py b/src/explainers/wrappers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/explainers/captum/base.py b/src/explainers/wrappers/base.py similarity index 92% rename from src/explainers/captum/base.py rename to src/explainers/wrappers/base.py index 6cc07ab0..38621b16 100644 --- a/src/explainers/captum/base.py +++ b/src/explainers/wrappers/base.py @@ -3,10 +3,10 @@ import torch from captum.influence import DataInfluence -from src.explainers.base import Explainer +from src.explainers.base_explainer import BaseExplainer -class CaptumExplainerWrapper(Explainer): +class CaptumExplainerWrapper(BaseExplainer): def __init__( self, model: torch.nn.Module, @@ -32,7 +32,7 @@ def initialize_captum(self, cls, **init_kwargs): self.captum_explainer = cls(model=self.model, train_dataset=self.train_dataset, **init_kwargs) def explain( - self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]], **explain_fn_kwargs + self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **explain_fn_kwargs ) -> torch.Tensor: test = test.to(self.device) if targets is not None: diff --git a/src/explainers/captum/similarity.py b/src/explainers/wrappers/similarity.py similarity index 93% rename from src/explainers/captum/similarity.py rename to src/explainers/wrappers/similarity.py index 0abc5358..b24731fe 100644 --- a/src/explainers/captum/similarity.py +++ b/src/explainers/wrappers/similarity.py @@ -3,7 +3,7 @@ import torch from captum.influence import SimilarityInfluence -from src.explainers.captum.base import CaptumExplainerWrapper +from src.explainers.wrappers.base import CaptumExplainerWrapper class CaptumSimilarityExplainer(CaptumExplainerWrapper): @@ -42,6 +42,9 @@ def initialize_captum(self, cls, **init_kwargs): def load_state_dict(self, path): return + def state_dict(self): + return + def reset(self): return diff --git a/src/utils/common.py b/src/utils/common.py index 8de51d1d..25a7fbfa 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -24,3 +24,15 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None = None, **kw _func_kwargs = kwargs return functools.partial(func, **_func_kwargs) + + +def cache_result(method): + cache_attr = f"_{method.__name__}_cache" + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + if cache_attr not in self.__dict__: + self.__dict__[cache_attr] = method(self, *args, **kwargs) + return self.__dict__[cache_attr] + + return wrapper diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py new file mode 100644 index 00000000..7ec30847 --- /dev/null +++ b/tests/explainers/test_base_explainer.py @@ -0,0 +1,44 @@ +import os +from typing import List, Optional, Union + +import pytest +import torch + +from explainers.base_explainer import BaseExplainer +from src.utils.functions.similarities import cosine_similarity + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, method_kwargs", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + {"layer": "relu_4", "similarity_metric": cosine_similarity}, + ), + ], +) +def test_base_explain_self_influence(test_id, model, dataset, method_kwargs, mocker, request): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + + BaseExplainer.__abstractmethods__ = set() + explainer = BaseExplainer( + model=model, + model_id="test_id", + cache_dir=os.path.join("./cache", "test_id"), + train_dataset=dataset, + device="cpu", + **method_kwargs, + ) + + # Patch the method + def mock_explain(test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None): + return torch.ones((test.shape[0], dataset.__len__())) + + mocker.patch.object(explainer, "explain", wraps=mock_explain) + + self_influence = explainer.self_influence() + assert self_influence.shape[0] == dataset.__len__(), "Self-influence shape does not match the dataset." diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index fd8be80a..24125c9f 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -3,8 +3,8 @@ import pytest import torch -from src.explainers.captum.similarity import CaptumSimilarityExplainer from src.explainers.functional import captum_similarity_explain +from src.explainers.wrappers.similarity import CaptumSimilarityExplainer from src.utils.functions.similarities import cosine_similarity @@ -18,7 +18,7 @@ "load_mnist_dataset", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - {"layer": "relu_4"}, + {"layer": "relu_4", "similarity_metric": cosine_similarity}, "load_mnist_explanations_1", ), ], @@ -71,5 +71,5 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te device="cpu", **method_kwargs, ) - explanations = explainer.explain(test_tensor, test_labels) + explanations = explainer.explain(test_tensor) assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index 3faf9018..9a69571d 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -4,8 +4,8 @@ import torch from torch.utils.data import TensorDataset -from src.explainers.captum.similarity import CaptumSimilarityExplainer from src.explainers.functional import captum_similarity_self_influence_ranking +from src.explainers.wrappers.similarity import CaptumSimilarityExplainer from src.utils.functions.similarities import dot_product_similarity From bfe71357d5e10d9d0d76284bd89abbb701be9f63 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Thu, 20 Jun 2024 11:33:52 +0200 Subject: [PATCH 18/30] fixing failing tests --- src/explainers/base_explainer.py | 2 +- src/explainers/functional.py | 2 +- src/utils/common.py | 4 ++-- tests/explainers/test_aggregators.py | 4 ++-- tests/explainers/test_base_explainer.py | 2 +- tests/explainers/test_self_influence.py | 9 +++------ tests/metrics/test_randomization_metrics.py | 4 ++-- 7 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/explainers/base_explainer.py b/src/explainers/base_explainer.py index 25c59624..73b98cbf 100644 --- a/src/explainers/base_explainer.py +++ b/src/explainers/base_explainer.py @@ -3,7 +3,7 @@ import torch -from utils.common import cache_result +from src.utils.common import cache_result class BaseExplainer(ABC): diff --git a/src/explainers/functional.py b/src/explainers/functional.py index 5888c891..0033458e 100644 --- a/src/explainers/functional.py +++ b/src/explainers/functional.py @@ -62,7 +62,7 @@ def explainer_self_influence_interface( device=device, **init_kwargs, ) - return explainer.self_influence_ranking() + return explainer.self_influence() def captum_similarity_explain( diff --git a/src/utils/common.py b/src/utils/common.py index 25a7fbfa..b07849ac 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -1,6 +1,6 @@ import functools from functools import reduce -from typing import Any, Callable, Mapping +from typing import Any, Callable, Mapping, Optional import torch import torch.utils @@ -15,7 +15,7 @@ def _get_parent_module_from_name(model: torch.nn.Module, layer_name: str) -> Any return reduce(getattr, layer_name.split(".")[:-1], model) -def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None = None, **kwargs) -> functools.partial: +def make_func(func: Callable, func_kwargs: Optional[Mapping[str, Any]] = None, **kwargs) -> functools.partial: """A function for creating a partial function with the given arguments.""" if func_kwargs is not None: _func_kwargs = kwargs.copy() diff --git a/tests/explainers/test_aggregators.py b/tests/explainers/test_aggregators.py index f800c5c5..0be36b35 100644 --- a/tests/explainers/test_aggregators.py +++ b/tests/explainers/test_aggregators.py @@ -18,7 +18,7 @@ def test_sum_aggregator(test_id, dataset, explanations, request): dataset = request.getfixturevalue(dataset) explanations = request.getfixturevalue(explanations) - aggregator = SumAggregator(training_size=len(dataset)) + aggregator = SumAggregator() aggregator.update(explanations) global_rank = aggregator.compute() assert torch.allclose(global_rank, explanations.sum(dim=0).argsort()) @@ -38,7 +38,7 @@ def test_sum_aggregator(test_id, dataset, explanations, request): def test_abs_aggregator(test_id, dataset, explanations, request): dataset = request.getfixturevalue(dataset) explanations = request.getfixturevalue(explanations) - aggregator = AbsSumAggregator(training_size=len(dataset)) + aggregator = AbsSumAggregator() aggregator.update(explanations) global_rank = aggregator.compute() assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort()) diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py index 7ec30847..9cd6e1aa 100644 --- a/tests/explainers/test_base_explainer.py +++ b/tests/explainers/test_base_explainer.py @@ -4,7 +4,7 @@ import pytest import torch -from explainers.base_explainer import BaseExplainer +from src.explainers.base_explainer import BaseExplainer from src.utils.functions.similarities import cosine_similarity diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index 9a69571d..a6bcab8c 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -1,3 +1,5 @@ +import os +import shutil from collections import OrderedDict import pytest @@ -21,12 +23,7 @@ ) def test_self_influence(test_id, init_kwargs, request): model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())])) - # X=torch.randn(1,200) - import os - import shutil - os.mkdir("temp_captum") - os.mkdir("temp_captum2") torch.random.manual_seed(42) X = torch.randn(100, 200) # rand_dataset = TensorDataset(X,torch.randint(0,10,(1,))) @@ -46,7 +43,7 @@ def test_self_influence(test_id, init_kwargs, request): explainer_obj = CaptumSimilarityExplainer( model=model, model_id="1", cache_dir="temp_captum2", train_dataset=rand_dataset, device="cpu", **init_kwargs ) - self_influence_rank_stateful = explainer_obj.self_influence_ranking() + self_influence_rank_stateful = explainer_obj.self_influence() if os.path.isdir("temp_captum2"): shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2")) diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 7d7edb20..fc518063 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -5,6 +5,7 @@ from src.metrics.randomization.model_randomization import ( ModelRandomizationMetric, ) +from src.utils.functions.similarities import cosine_similarity @pytest.mark.randomization_metrics @@ -18,7 +19,7 @@ "load_mnist_test_samples_1", 8, { - "layer": "fc_2", + "layer": "fc_2", "similarity_metric": cosine_similarity, }, "load_mnist_explanations_1", "load_mnist_test_labels_1", @@ -31,7 +32,6 @@ def test_randomization_metric_functional( model = request.getfixturevalue(model) test_data = request.getfixturevalue(test_data) dataset = request.getfixturevalue(dataset) - explain_init_kwargs = request.getfixturevalue(explain_init_kwargs) test_labels = request.getfixturevalue(test_labels) tda = request.getfixturevalue(explanations) metric = ModelRandomizationMetric( From 8c83a1f64591dab807bc21a3b249a2f909f8d0d5 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Thu, 20 Jun 2024 11:38:05 +0200 Subject: [PATCH 19/30] Revert "fixing explainer tests" This reverts commit 888aefd0eb826700883474666635919d183219f9. --- src/explainers/aggregators.py | 27 ++++-------- src/explainers/{base_explainer.py => base.py} | 35 +++++++-------- src/explainers/{wrappers => captum}/base.py | 6 +-- .../{wrappers => captum}/similarity.py | 5 +-- src/explainers/functional.py | 8 ++-- src/explainers/wrappers/__init__.py | 0 src/utils/common.py | 12 ----- tests/explainers/test_base_explainer.py | 44 ------------------- tests/explainers/test_explainers.py | 6 +-- tests/explainers/test_self_influence.py | 2 +- 10 files changed, 36 insertions(+), 109 deletions(-) rename src/explainers/{base_explainer.py => base.py} (53%) rename src/explainers/{wrappers => captum}/base.py (92%) rename src/explainers/{wrappers => captum}/similarity.py (93%) delete mode 100644 src/explainers/wrappers/__init__.py delete mode 100644 tests/explainers/test_base_explainer.py diff --git a/src/explainers/aggregators.py b/src/explainers/aggregators.py index 4d690f67..0d36b6b4 100644 --- a/src/explainers/aggregators.py +++ b/src/explainers/aggregators.py @@ -3,28 +3,19 @@ import torch -class BaseAggregator(ABC): - def __init__(self): - self.scores: torch.Tensor = None +class ExplanationsAggregator(ABC): + def __init__(self, training_size: int, *args, **kwargs): + self.scores = torch.zeros(training_size) @abstractmethod def update(self, explanations: torch.Tensor): raise NotImplementedError - def _validate_explanations(self, explanations: torch.Tensor): - if self.scores is None: - self.scores = torch.zeros(explanations.shape[1]) - - if explanations.shape[1] != self.scores.shape[0]: - raise ValueError( - f"Explanations shape {explanations.shape} does not match the expected shape {self.scores.shape}" - ) - def reset(self, *args, **kwargs): """ Used to reset the aggregator state. """ - self.scores: torch.Tensor = None + self.scores = torch.zeros_like(self.scores) def load_state_dict(self, state_dict: dict, *args, **kwargs): """ @@ -42,13 +33,11 @@ def compute(self) -> torch.Tensor: return self.scores.argsort() -class SumAggregator(BaseAggregator): - def update(self, explanations: torch.Tensor): - self._validate_explanations(explanations) +class SumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: self.scores += explanations.sum(dim=0) -class AbsSumAggregator(BaseAggregator): - def update(self, explanations: torch.Tensor): - self._validate_explanations(explanations) +class AbsSumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: self.scores += explanations.abs().sum(dim=0) diff --git a/src/explainers/base_explainer.py b/src/explainers/base.py similarity index 53% rename from src/explainers/base_explainer.py rename to src/explainers/base.py index 25c59624..510b8d0d 100644 --- a/src/explainers/base_explainer.py +++ b/src/explainers/base.py @@ -3,10 +3,8 @@ import torch -from utils.common import cache_result - -class BaseExplainer(ABC): +class Explainer(ABC): def __init__( self, model: torch.nn.Module, @@ -16,23 +14,22 @@ def __init__( device: Union[str, torch.device], **kwargs, ): + self.device = torch.device(device) if isinstance(device, str) else device + self.train_dataset = train_dataset + self._self_influences = None self.model = model - self.model.to(device) - + self.model.to(self.device) self.model_id = model_id self.cache_dir = cache_dir - self.train_dataset = train_dataset - self.device = torch.device(device) if isinstance(device, str) else device @abstractmethod - def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs): + def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]], **kwargs): raise NotImplementedError @abstractmethod def load_state_dict(self, path): raise NotImplementedError - @abstractmethod def state_dict(self): raise NotImplementedError @@ -40,14 +37,14 @@ def state_dict(self): def reset(self): raise NotImplementedError - @cache_result - def self_influence(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: + def self_influence_ranking(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: # Base class implements computing self influences by explaining the train dataset one by one - influences = torch.empty((len(self.train_dataset),), device=self.device) - ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) - for i, (x, y) in enumerate(iter(ldr)): - upper_index = i * batch_size + x.shape[0] - explanations = self.explain(test=x.to(self.device), **kwargs) - explanations = explanations[:, i * batch_size : upper_index] - influences[i * batch_size : upper_index] = explanations.diag() - return influences.argsort() + if self._self_influences is None: + self._self_influences = torch.empty((len(self.train_dataset),), device=self.device) + ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) + for i, (x, y) in enumerate(iter(ldr)): + upper_index = i * batch_size + x.shape[0] + explanations = self.explain(test=x.to(self.device), **kwargs) + explanations = explanations[:, i * batch_size : upper_index] + self._self_influences[i * batch_size : upper_index] = explanations.diag() + return self._self_influences.argsort() diff --git a/src/explainers/wrappers/base.py b/src/explainers/captum/base.py similarity index 92% rename from src/explainers/wrappers/base.py rename to src/explainers/captum/base.py index 38621b16..6cc07ab0 100644 --- a/src/explainers/wrappers/base.py +++ b/src/explainers/captum/base.py @@ -3,10 +3,10 @@ import torch from captum.influence import DataInfluence -from src.explainers.base_explainer import BaseExplainer +from src.explainers.base import Explainer -class CaptumExplainerWrapper(BaseExplainer): +class CaptumExplainerWrapper(Explainer): def __init__( self, model: torch.nn.Module, @@ -32,7 +32,7 @@ def initialize_captum(self, cls, **init_kwargs): self.captum_explainer = cls(model=self.model, train_dataset=self.train_dataset, **init_kwargs) def explain( - self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **explain_fn_kwargs + self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]], **explain_fn_kwargs ) -> torch.Tensor: test = test.to(self.device) if targets is not None: diff --git a/src/explainers/wrappers/similarity.py b/src/explainers/captum/similarity.py similarity index 93% rename from src/explainers/wrappers/similarity.py rename to src/explainers/captum/similarity.py index b24731fe..0abc5358 100644 --- a/src/explainers/wrappers/similarity.py +++ b/src/explainers/captum/similarity.py @@ -3,7 +3,7 @@ import torch from captum.influence import SimilarityInfluence -from src.explainers.wrappers.base import CaptumExplainerWrapper +from src.explainers.captum.base import CaptumExplainerWrapper class CaptumSimilarityExplainer(CaptumExplainerWrapper): @@ -42,9 +42,6 @@ def initialize_captum(self, cls, **init_kwargs): def load_state_dict(self, path): return - def state_dict(self): - return - def reset(self): return diff --git a/src/explainers/functional.py b/src/explainers/functional.py index 5888c891..6f1b3a9a 100644 --- a/src/explainers/functional.py +++ b/src/explainers/functional.py @@ -2,8 +2,8 @@ import torch -from src.explainers.base_explainer import BaseExplainer -from src.explainers.wrappers.similarity import CaptumSimilarityExplainer +from src.explainers.base import Explainer +from src.explainers.captum.similarity import CaptumSimilarityExplainer class ExplainFunc(Protocol): @@ -23,7 +23,7 @@ def __call__( def explainer_functional_interface( - explainer_cls: BaseExplainer, + explainer_cls: Explainer, model: torch.nn.Module, model_id: str, cache_dir: Optional[str], @@ -46,7 +46,7 @@ def explainer_functional_interface( def explainer_self_influence_interface( - explainer_cls: BaseExplainer, + explainer_cls: Explainer, model: torch.nn.Module, model_id: str, cache_dir: Optional[str], diff --git a/src/explainers/wrappers/__init__.py b/src/explainers/wrappers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/common.py b/src/utils/common.py index 25a7fbfa..8de51d1d 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -24,15 +24,3 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None = None, **kw _func_kwargs = kwargs return functools.partial(func, **_func_kwargs) - - -def cache_result(method): - cache_attr = f"_{method.__name__}_cache" - - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - if cache_attr not in self.__dict__: - self.__dict__[cache_attr] = method(self, *args, **kwargs) - return self.__dict__[cache_attr] - - return wrapper diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py deleted file mode 100644 index 7ec30847..00000000 --- a/tests/explainers/test_base_explainer.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -from typing import List, Optional, Union - -import pytest -import torch - -from explainers.base_explainer import BaseExplainer -from src.utils.functions.similarities import cosine_similarity - - -@pytest.mark.explainers -@pytest.mark.parametrize( - "test_id, model, dataset, method_kwargs", - [ - ( - "mnist", - "load_mnist_model", - "load_mnist_dataset", - {"layer": "relu_4", "similarity_metric": cosine_similarity}, - ), - ], -) -def test_base_explain_self_influence(test_id, model, dataset, method_kwargs, mocker, request): - model = request.getfixturevalue(model) - dataset = request.getfixturevalue(dataset) - - BaseExplainer.__abstractmethods__ = set() - explainer = BaseExplainer( - model=model, - model_id="test_id", - cache_dir=os.path.join("./cache", "test_id"), - train_dataset=dataset, - device="cpu", - **method_kwargs, - ) - - # Patch the method - def mock_explain(test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None): - return torch.ones((test.shape[0], dataset.__len__())) - - mocker.patch.object(explainer, "explain", wraps=mock_explain) - - self_influence = explainer.self_influence() - assert self_influence.shape[0] == dataset.__len__(), "Self-influence shape does not match the dataset." diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index 24125c9f..fd8be80a 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -3,8 +3,8 @@ import pytest import torch +from src.explainers.captum.similarity import CaptumSimilarityExplainer from src.explainers.functional import captum_similarity_explain -from src.explainers.wrappers.similarity import CaptumSimilarityExplainer from src.utils.functions.similarities import cosine_similarity @@ -18,7 +18,7 @@ "load_mnist_dataset", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - {"layer": "relu_4", "similarity_metric": cosine_similarity}, + {"layer": "relu_4"}, "load_mnist_explanations_1", ), ], @@ -71,5 +71,5 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te device="cpu", **method_kwargs, ) - explanations = explainer.explain(test_tensor) + explanations = explainer.explain(test_tensor, test_labels) assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index 9a69571d..3faf9018 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -4,8 +4,8 @@ import torch from torch.utils.data import TensorDataset +from src.explainers.captum.similarity import CaptumSimilarityExplainer from src.explainers.functional import captum_similarity_self_influence_ranking -from src.explainers.wrappers.similarity import CaptumSimilarityExplainer from src.utils.functions.similarities import dot_product_similarity From 0a477f546303ad916b335a69d225729fa933a6f9 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Thu, 20 Jun 2024 16:06:24 +0200 Subject: [PATCH 20/30] refactoring captum wrappers --- src/explainers/base_explainer.py | 18 +-- src/explainers/functional.py | 34 +++--- src/explainers/wrappers/base.py | 49 -------- src/explainers/wrappers/captum_influence.py | 120 ++++++++++++++++++++ src/explainers/wrappers/similarity.py | 55 --------- src/utils/validation.py | 19 ++++ tests/explainers/test_base_explainer.py | 2 +- tests/explainers/test_explainers.py | 10 +- tests/explainers/test_self_influence.py | 10 +- tests/metrics/test_randomization_metrics.py | 3 +- 10 files changed, 171 insertions(+), 149 deletions(-) delete mode 100644 src/explainers/wrappers/base.py create mode 100644 src/explainers/wrappers/captum_influence.py delete mode 100644 src/explainers/wrappers/similarity.py create mode 100644 src/utils/validation.py diff --git a/src/explainers/base_explainer.py b/src/explainers/base_explainer.py index 73b98cbf..abd7e41c 100644 --- a/src/explainers/base_explainer.py +++ b/src/explainers/base_explainer.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Union +from typing import List, Optional, Union, Any import torch @@ -25,23 +25,11 @@ def __init__( self.device = torch.device(device) if isinstance(device, str) else device @abstractmethod - def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs): - raise NotImplementedError - - @abstractmethod - def load_state_dict(self, path): - raise NotImplementedError - - @abstractmethod - def state_dict(self): - raise NotImplementedError - - @abstractmethod - def reset(self): + def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs: Any): raise NotImplementedError @cache_result - def self_influence(self, batch_size: Optional[int] = 32, **kwargs) -> torch.Tensor: + def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch.Tensor: # Base class implements computing self influences by explaining the train dataset one by one influences = torch.empty((len(self.train_dataset),), device=self.device) ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) diff --git a/src/explainers/functional.py b/src/explainers/functional.py index 0033458e..a9fd945d 100644 --- a/src/explainers/functional.py +++ b/src/explainers/functional.py @@ -3,7 +3,7 @@ import torch from src.explainers.base_explainer import BaseExplainer -from src.explainers.wrappers.similarity import CaptumSimilarityExplainer +from src.explainers.wrappers.captum_influence import CaptumSimilarity class ExplainFunc(Protocol): @@ -22,8 +22,8 @@ def __call__( pass -def explainer_functional_interface( - explainer_cls: BaseExplainer, +def explain_fn_from_explainer( + explainer_cls: type, model: torch.nn.Module, model_id: str, cache_dir: Optional[str], @@ -34,19 +34,19 @@ def explainer_functional_interface( init_kwargs: Optional[Dict] = {}, explain_kwargs: Optional[Dict] = {}, ) -> torch.Tensor: - explainer = explainer_cls( - model=model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - device=device, - **init_kwargs, - ) - return explainer.explain(test=test_tensor, **explain_kwargs) + explainer = explainer_cls( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + explainer_kwargs=init_kwargs, + ) + return explainer.explain(test=test_tensor, **explain_kwargs) def explainer_self_influence_interface( - explainer_cls: BaseExplainer, + explainer_cls: type, model: torch.nn.Module, model_id: str, cache_dir: Optional[str], @@ -60,7 +60,7 @@ def explainer_self_influence_interface( cache_dir=cache_dir, train_dataset=train_dataset, device=device, - **init_kwargs, + explainer_kwargs=init_kwargs, ) return explainer.self_influence() @@ -76,8 +76,8 @@ def captum_similarity_explain( init_kwargs: Optional[Dict] = {}, explain_kwargs: Optional[Dict] = {}, ) -> torch.Tensor: - return explainer_functional_interface( - explainer_cls=CaptumSimilarityExplainer, + return explain_fn_from_explainer( + explainer_cls=CaptumSimilarity, model=model, model_id=model_id, cache_dir=cache_dir, @@ -99,7 +99,7 @@ def captum_similarity_self_influence_ranking( device: Union[str, torch.device], ) -> torch.Tensor: return explainer_self_influence_interface( - explainer_cls=CaptumSimilarityExplainer, + explainer_cls=CaptumSimilarity, model=model, model_id=model_id, cache_dir=cache_dir, diff --git a/src/explainers/wrappers/base.py b/src/explainers/wrappers/base.py deleted file mode 100644 index 38621b16..00000000 --- a/src/explainers/wrappers/base.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import List, Optional, Union - -import torch -from captum.influence import DataInfluence - -from src.explainers.base_explainer import BaseExplainer - - -class CaptumExplainerWrapper(BaseExplainer): - def __init__( - self, - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - train_dataset: torch.utils.data.Dataset, - device: Union[str, torch.device], - explainer_cls: DataInfluence, - **explainer_init_kwargs, - ): - super().__init__( - model=model, model_id=model_id, train_dataset=train_dataset, device=device, cache_dir=cache_dir - ) - for shared_field_name in ["model_id", "cache_dir"]: - assert shared_field_name not in explainer_init_kwargs.keys(), ( - f"{shared_field_name} is already given to the explainer object, " - "it must not be repeated in the explainer_init_kwargs" - ) - - self.initialize_captum(explainer_cls, **explainer_init_kwargs) - - def initialize_captum(self, cls, **init_kwargs): - self.captum_explainer = cls(model=self.model, train_dataset=self.train_dataset, **init_kwargs) - - def explain( - self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **explain_fn_kwargs - ) -> torch.Tensor: - test = test.to(self.device) - if targets is not None: - if not isinstance(targets, torch.Tensor): - if isinstance(targets, list): - targets = torch.tensor(targets) - else: - raise TypeError( - f"targets should be of type NoneType, List or torch.Tensor. Got {type(targets)} instead." - ) - targets = targets.to(self.device) - return self.captum_explainer.influence(inputs=(test, targets), **explain_fn_kwargs) - else: - return self.captum_explainer.influence(inputs=test, **explain_fn_kwargs) diff --git a/src/explainers/wrappers/captum_influence.py b/src/explainers/wrappers/captum_influence.py new file mode 100644 index 00000000..9e0948a9 --- /dev/null +++ b/src/explainers/wrappers/captum_influence.py @@ -0,0 +1,120 @@ +from typing import List, Optional, Union, Any, Dict + +import torch +from captum.influence import SimilarityInfluence + +from src.explainers.base_explainer import BaseExplainer +from src.utils.validation import validate_1d_tensor_or_int_list + + +class CaptumInfluence(BaseExplainer): + """ + TODO: should this inherit from BaseExplainer? + Or should it just follow the same protocol? + """ + + def __init__( + self, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + explainer_cls: type, + explain_kwargs: Dict[str, Any], + **kwargs, + ): + super().__init__( + model=model, model_id=model_id, cache_dir=cache_dir, train_dataset=train_dataset, device=device, + ) + self.explainer_cls = explainer_cls + self.explain_kwargs = explain_kwargs + self._init_explainer(explainer_cls, **explain_kwargs) + + def _init_explainer(self, cls: type, **explain_kwargs: Any): + self.captum_explainer = cls(**explain_kwargs) + if not isinstance(self.captum_explainer, self.explainer_cls): + raise ValueError(f"Expected {self.explainer_cls}, but got {type(self.captum_explainer)}") + + def _process_targets(self, targets: Optional[Union[List[int], torch.Tensor]]): + if targets is not None: + # TODO: move validation logic outside at a later point + validate_1d_tensor_or_int_list(targets) + if isinstance(targets, list): + targets = torch.tensor(targets) + targets = targets.to(self.device) + return targets + + def explain( + self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **explain_fn_kwargs + ) -> torch.Tensor: + # Process inputs + test = test.to(self.device) + targets = self._process_targets(targets) + + if targets is not None: + return self.captum_explainer.influence(inputs=(test, targets), **explain_fn_kwargs) + else: + return self.captum_explainer.influence(inputs=test, **explain_fn_kwargs) + + +class CaptumSimilarity(CaptumInfluence): + def __init__( + self, + model: torch.nn.Module, + model_id: str, + cache_dir: str, + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + explainer_kwargs: Dict[str, Any], + ): + # extract and validate layer from kwargs + self._layer: Union[List[str], str] = None + self.layer = explainer_kwargs.get("layers", []) + + explainer_kwargs = { + "module": model, + "influence_src_dataset": train_dataset, + "activation_dir": cache_dir, + "model_id": model_id, + "similarity_direction": "max", + **explainer_kwargs, + } + + super().__init__( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + explainer_cls=SimilarityInfluence, + explain_kwargs=explainer_kwargs, + ) + + @property + def layer(self): + return self._layer + + @layer.setter + def layer(self, layers: Any): + """ + Our wrapper only allows a single layer to be passed, while the Captum implementation allows multiple layers. + Here, we validate if there is only a single layer passed. + """ + if isinstance(layers, str): + self._layer = layers + return + if len(layers) != 1: + raise ValueError("A single layer shall be passed to the CaptumSimilarity explainer.") + self._layer = layers[0] + + def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs: Any): + # We might want to pass the top_k as an argument in some scenarios + top_k = kwargs.get("top_k", len(self.train_dataset)) + + topk_idx, topk_val = super().explain(test=test, top_k=top_k, **kwargs)[self.layer] + inverted_idx = topk_idx.argsort() + # Note to Galip: this is equivalent to your implementation (I checked the values) + tda = torch.gather(topk_val, 1, inverted_idx) + + return tda diff --git a/src/explainers/wrappers/similarity.py b/src/explainers/wrappers/similarity.py deleted file mode 100644 index b24731fe..00000000 --- a/src/explainers/wrappers/similarity.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Union - -import torch -from captum.influence import SimilarityInfluence - -from src.explainers.wrappers.base import CaptumExplainerWrapper - - -class CaptumSimilarityExplainer(CaptumExplainerWrapper): - def __init__( - self, - model: torch.nn.Module, - model_id: str, - cache_dir: str, - train_dataset: torch.utils.data.Dataset, - device: Union[str, torch.device], - layer: str, - **explainer_init_kwargs, - ): - self.layer = layer - super().__init__( - model=model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - device=device, - explainer_cls=SimilarityInfluence, - layers=[layer], - **explainer_init_kwargs, - ) - - def initialize_captum(self, cls, **init_kwargs): - self.captum_explainer = cls( - module=self.model, - influence_src_dataset=self.train_dataset, - activation_dir=self.cache_dir, - model_id=self.model_id, - similarity_direction="max", - **init_kwargs, - ) - - def load_state_dict(self, path): - return - - def state_dict(self): - return - - def reset(self): - return - - def explain(self, test: torch.Tensor) -> torch.Tensor: - topk_idx, topk_val = super().explain(test=test, targets=None, top_k=len(self.train_dataset))[self.layer] - inverted_idx = topk_idx.argsort() - tda = torch.cat([topk_val[None, i, inverted_idx[i]] for i in range(topk_idx.shape[0])], dim=0) - return tda diff --git a/src/utils/validation.py b/src/utils/validation.py new file mode 100644 index 00000000..345ee5e9 --- /dev/null +++ b/src/utils/validation.py @@ -0,0 +1,19 @@ +import torch + +""" +This is a Python module that contains helper functions for validating input arguments. +The plan is to collect them here and then create a universal validation decorator @validate_args +to check all the input arguments against the expected types specified e.g. +as class attributes. +""" + + +def validate_1d_tensor_or_int_list(targets): + if isinstance(targets, torch.Tensor): + if len(targets.shape) != 1: + raise ValueError(f"targets should be a 1D tensor. Got shape {targets.shape} instead.") + elif isinstance(targets, list): + if not all(isinstance(x, int) for x in targets): + raise ValueError(f"targets should be a list of integers. Got {targets} instead.") + else: + raise TypeError(f"targets should be of type List or torch.Tensor. Got {type(targets)} instead.") diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py index 9cd6e1aa..113f0f7e 100644 --- a/tests/explainers/test_base_explainer.py +++ b/tests/explainers/test_base_explainer.py @@ -16,7 +16,7 @@ "mnist", "load_mnist_model", "load_mnist_dataset", - {"layer": "relu_4", "similarity_metric": cosine_similarity}, + {"layers": "relu_4", "similarity_metric": cosine_similarity}, ), ], ) diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index 24125c9f..b2ad011a 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -4,7 +4,7 @@ import torch from src.explainers.functional import captum_similarity_explain -from src.explainers.wrappers.similarity import CaptumSimilarityExplainer +from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.utils.functions.similarities import cosine_similarity @@ -18,7 +18,7 @@ "load_mnist_dataset", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - {"layer": "relu_4", "similarity_metric": cosine_similarity}, + {"layers": "relu_4", "similarity_metric": cosine_similarity}, "load_mnist_explanations_1", ), ], @@ -53,7 +53,7 @@ def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, m "load_mnist_explanations_1", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - {"layer": "relu_4", "similarity_metric": cosine_similarity}, + {"layers": "relu_4", "similarity_metric": cosine_similarity}, ), ], ) @@ -63,13 +63,13 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te test_tensor = request.getfixturevalue(test_tensor) test_labels = request.getfixturevalue(test_labels) explanations_exp = request.getfixturevalue(explanations) - explainer = CaptumSimilarityExplainer( + explainer = CaptumSimilarity( model=model, model_id="test_id", cache_dir=os.path.join("./cache", "test_id"), train_dataset=dataset, device="cpu", - **method_kwargs, + explainer_kwargs=method_kwargs, ) explanations = explainer.explain(test_tensor) assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index a6bcab8c..e24dfa2c 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -7,7 +7,7 @@ from torch.utils.data import TensorDataset from src.explainers.functional import captum_similarity_self_influence_ranking -from src.explainers.wrappers.similarity import CaptumSimilarityExplainer +from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.utils.functions.similarities import dot_product_similarity @@ -17,7 +17,7 @@ [ ( "random_data", - {"layer": "identity", "similarity_metric": dot_product_similarity}, + {"layers": "identity", "similarity_metric": dot_product_similarity}, ), ], ) @@ -26,10 +26,8 @@ def test_self_influence(test_id, init_kwargs, request): torch.random.manual_seed(42) X = torch.randn(100, 200) - # rand_dataset = TensorDataset(X,torch.randint(0,10,(1,))) y = torch.randint(0, 10, (100,)) rand_dataset = TensorDataset(X, y) - init_kwargs = {"layer": "identity", "similarity_metric": dot_product_similarity} self_influence_rank_functional = captum_similarity_self_influence_ranking( model=model, @@ -40,8 +38,8 @@ def test_self_influence(test_id, init_kwargs, request): device="cpu", ) - explainer_obj = CaptumSimilarityExplainer( - model=model, model_id="1", cache_dir="temp_captum2", train_dataset=rand_dataset, device="cpu", **init_kwargs + explainer_obj = CaptumSimilarity( + model=model, model_id="1", cache_dir="temp_captum2", train_dataset=rand_dataset, device="cpu", explainer_kwargs=init_kwargs ) self_influence_rank_stateful = explainer_obj.self_influence() diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index fc518063..1ac6137b 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -19,7 +19,8 @@ "load_mnist_test_samples_1", 8, { - "layer": "fc_2", "similarity_metric": cosine_similarity, + "layers": "fc_2", + "similarity_metric": cosine_similarity, }, "load_mnist_explanations_1", "load_mnist_test_labels_1", From 950b8b86d18dd8135dd737af438cf1f9d6114929 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Thu, 20 Jun 2024 17:13:50 +0200 Subject: [PATCH 21/30] make kwargs **kwargs again --- src/explainers/base_explainer.py | 2 +- src/explainers/functional.py | 21 ++++++------ src/explainers/wrappers/captum_influence.py | 37 ++++++++++++--------- src/utils/validation.py | 6 ++-- tests/explainers/test_explainers.py | 2 +- tests/explainers/test_self_influence.py | 7 +++- 6 files changed, 41 insertions(+), 34 deletions(-) diff --git a/src/explainers/base_explainer.py b/src/explainers/base_explainer.py index abd7e41c..d0bb5796 100644 --- a/src/explainers/base_explainer.py +++ b/src/explainers/base_explainer.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Union, Any +from typing import Any, List, Optional, Union import torch diff --git a/src/explainers/functional.py b/src/explainers/functional.py index a9fd945d..528a1f51 100644 --- a/src/explainers/functional.py +++ b/src/explainers/functional.py @@ -2,7 +2,6 @@ import torch -from src.explainers.base_explainer import BaseExplainer from src.explainers.wrappers.captum_influence import CaptumSimilarity @@ -34,15 +33,15 @@ def explain_fn_from_explainer( init_kwargs: Optional[Dict] = {}, explain_kwargs: Optional[Dict] = {}, ) -> torch.Tensor: - explainer = explainer_cls( - model=model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - device=device, - explainer_kwargs=init_kwargs, - ) - return explainer.explain(test=test_tensor, **explain_kwargs) + explainer = explainer_cls( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + **init_kwargs, + ) + return explainer.explain(test=test_tensor, **explain_kwargs) def explainer_self_influence_interface( @@ -60,7 +59,7 @@ def explainer_self_influence_interface( cache_dir=cache_dir, train_dataset=train_dataset, device=device, - explainer_kwargs=init_kwargs, + **init_kwargs, ) return explainer.self_influence() diff --git a/src/explainers/wrappers/captum_influence.py b/src/explainers/wrappers/captum_influence.py index 9e0948a9..9b3699e5 100644 --- a/src/explainers/wrappers/captum_influence.py +++ b/src/explainers/wrappers/captum_influence.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union, Any, Dict +from typing import Any, List, Optional, Union import torch from captum.influence import SimilarityInfluence @@ -21,11 +21,14 @@ def __init__( train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], explainer_cls: type, - explain_kwargs: Dict[str, Any], - **kwargs, + **explain_kwargs: Any, ): super().__init__( - model=model, model_id=model_id, cache_dir=cache_dir, train_dataset=train_dataset, device=device, + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, ) self.explainer_cls = explainer_cls self.explain_kwargs = explain_kwargs @@ -66,29 +69,31 @@ def __init__( cache_dir: str, train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], - explainer_kwargs: Dict[str, Any], + **explainer_kwargs: Any, ): # extract and validate layer from kwargs - self._layer: Union[List[str], str] = None + self._layer: Optional[Union[List[str], str]] = None self.layer = explainer_kwargs.get("layers", []) - explainer_kwargs = { - "module": model, - "influence_src_dataset": train_dataset, - "activation_dir": cache_dir, - "model_id": model_id, - "similarity_direction": "max", - **explainer_kwargs, - } + # TODO: validate SimilarityInfluence kwargs + explainer_kwargs.update( + { + "module": model, + "influence_src_dataset": train_dataset, + "activation_dir": cache_dir, + "model_id": model_id, + "similarity_direction": "max", + **explainer_kwargs, + } + ) super().__init__( model=model, - model_id=model_id, cache_dir=cache_dir, train_dataset=train_dataset, device=device, explainer_cls=SimilarityInfluence, - explain_kwargs=explainer_kwargs, + **explainer_kwargs, ) @property diff --git a/src/utils/validation.py b/src/utils/validation.py index 345ee5e9..330070a0 100644 --- a/src/utils/validation.py +++ b/src/utils/validation.py @@ -1,10 +1,8 @@ import torch """ -This is a Python module that contains helper functions for validating input arguments. -The plan is to collect them here and then create a universal validation decorator @validate_args -to check all the input arguments against the expected types specified e.g. -as class attributes. +This module contains utility functions for validation. The plan is to +move the validation logic into a validation decorator at a later point. """ diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index b2ad011a..78ee3e71 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -69,7 +69,7 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te cache_dir=os.path.join("./cache", "test_id"), train_dataset=dataset, device="cpu", - explainer_kwargs=method_kwargs, + **method_kwargs, ) explanations = explainer.explain(test_tensor) assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index e24dfa2c..d2487d58 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -39,7 +39,12 @@ def test_self_influence(test_id, init_kwargs, request): ) explainer_obj = CaptumSimilarity( - model=model, model_id="1", cache_dir="temp_captum2", train_dataset=rand_dataset, device="cpu", explainer_kwargs=init_kwargs + model=model, + model_id="1", + cache_dir="temp_captum2", + train_dataset=rand_dataset, + device="cpu", + **init_kwargs, ) self_influence_rank_stateful = explainer_obj.self_influence() From 5cded75a771142a857895e028b93ae07c263b71d Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Thu, 20 Jun 2024 17:37:36 +0200 Subject: [PATCH 22/30] sort explain fns --- src/explainers/aggregators.py | 3 +- src/explainers/functional.py | 89 --------------------- src/explainers/utils.py | 56 +++++++++++++ src/explainers/wrappers/captum_influence.py | 57 ++++++++++++- tests/explainers/test_explainers.py | 6 +- tests/explainers/test_self_influence.py | 8 +- tests/metrics/test_randomization_metrics.py | 2 +- 7 files changed, 124 insertions(+), 97 deletions(-) create mode 100644 src/explainers/utils.py diff --git a/src/explainers/aggregators.py b/src/explainers/aggregators.py index 4d690f67..3a3ead5e 100644 --- a/src/explainers/aggregators.py +++ b/src/explainers/aggregators.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod +from typing import Optional import torch class BaseAggregator(ABC): def __init__(self): - self.scores: torch.Tensor = None + self.scores: Optional[torch.Tensor] = None @abstractmethod def update(self, explanations: torch.Tensor): diff --git a/src/explainers/functional.py b/src/explainers/functional.py index 528a1f51..8e7b33b8 100644 --- a/src/explainers/functional.py +++ b/src/explainers/functional.py @@ -2,8 +2,6 @@ import torch -from src.explainers.wrappers.captum_influence import CaptumSimilarity - class ExplainFunc(Protocol): def __call__( @@ -19,90 +17,3 @@ def __call__( device: Union[str, torch.device], ) -> torch.Tensor: pass - - -def explain_fn_from_explainer( - explainer_cls: type, - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - test_tensor: torch.Tensor, - explanation_targets: Optional[Union[List[int], torch.Tensor]], - train_dataset: torch.utils.data.Dataset, - device: Union[str, torch.device], - init_kwargs: Optional[Dict] = {}, - explain_kwargs: Optional[Dict] = {}, -) -> torch.Tensor: - explainer = explainer_cls( - model=model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - device=device, - **init_kwargs, - ) - return explainer.explain(test=test_tensor, **explain_kwargs) - - -def explainer_self_influence_interface( - explainer_cls: type, - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - train_dataset: torch.utils.data.Dataset, - init_kwargs: Dict, - device: Union[str, torch.device], -) -> torch.Tensor: - explainer = explainer_cls( - model=model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - device=device, - **init_kwargs, - ) - return explainer.self_influence() - - -def captum_similarity_explain( - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - test_tensor: torch.Tensor, - explanation_targets: Optional[Union[List[int], torch.Tensor]], - train_dataset: torch.utils.data.Dataset, - device: Union[str, torch.device], - init_kwargs: Optional[Dict] = {}, - explain_kwargs: Optional[Dict] = {}, -) -> torch.Tensor: - return explain_fn_from_explainer( - explainer_cls=CaptumSimilarity, - model=model, - model_id=model_id, - cache_dir=cache_dir, - test_tensor=test_tensor, - explanation_targets=explanation_targets, - train_dataset=train_dataset, - device=device, - init_kwargs=init_kwargs, - explain_kwargs=explain_kwargs, - ) - - -def captum_similarity_self_influence_ranking( - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - train_dataset: torch.utils.data.Dataset, - init_kwargs: Dict, - device: Union[str, torch.device], -) -> torch.Tensor: - return explainer_self_influence_interface( - explainer_cls=CaptumSimilarity, - model=model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - device=device, - init_kwargs=init_kwargs, - ) diff --git a/src/explainers/utils.py b/src/explainers/utils.py new file mode 100644 index 00000000..f7e121fc --- /dev/null +++ b/src/explainers/utils.py @@ -0,0 +1,56 @@ +from typing import Dict, List, Optional, Union + +import torch + + +def explain_fn_from_explainer( + explainer_cls: type, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + test_tensor: torch.Tensor, + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + targets: Optional[Union[List[int], torch.Tensor]] = None, + init_kwargs: Optional[Dict] = None, + explain_kwargs: Optional[Dict] = None, +) -> torch.Tensor: + + init_kwargs = init_kwargs or {} + explain_kwargs = explain_kwargs or {} + + explainer = explainer_cls( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + **init_kwargs, + ) + return explainer.explain(test=test_tensor, targets=targets, **explain_kwargs) + + +def self_influence_fn_from_explainer( + explainer_cls: type, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + batch_size: Optional[int] = 32, + init_kwargs: Optional[Dict] = None, + explain_kwargs: Optional[Dict] = None, +) -> torch.Tensor: + + init_kwargs = init_kwargs or {} + explain_kwargs = explain_kwargs or {} + + explainer = explainer_cls( + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + **init_kwargs, + ) + return explainer.self_influence(batch_size=batch_size, **explain_kwargs) diff --git a/src/explainers/wrappers/captum_influence.py b/src/explainers/wrappers/captum_influence.py index 9b3699e5..ebcb76b5 100644 --- a/src/explainers/wrappers/captum_influence.py +++ b/src/explainers/wrappers/captum_influence.py @@ -1,9 +1,13 @@ -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from captum.influence import SimilarityInfluence from src.explainers.base_explainer import BaseExplainer +from src.explainers.utils import ( + explain_fn_from_explainer, + self_influence_fn_from_explainer, +) from src.utils.validation import validate_1d_tensor_or_int_list @@ -123,3 +127,54 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T tda = torch.gather(topk_val, 1, inverted_idx) return tda + + +def captum_similarity_explain( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + test_tensor: torch.Tensor, + explanation_targets: Optional[Union[List[int], torch.Tensor]], + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + init_kwargs: Optional[Dict] = None, + explain_kwargs: Optional[Dict] = None, +) -> torch.Tensor: + + init_kwargs = init_kwargs or {} + explain_kwargs = explain_kwargs or {} + + return explain_fn_from_explainer( + explainer_cls=CaptumSimilarity, + model=model, + model_id=model_id, + cache_dir=cache_dir, + test_tensor=test_tensor, + targets=explanation_targets, + train_dataset=train_dataset, + device=device, + init_kwargs=init_kwargs, + explain_kwargs=explain_kwargs, + ) + + +def captum_similarity_self_influence( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + init_kwargs: Dict, + device: Union[str, torch.device], +) -> torch.Tensor: + + init_kwargs = init_kwargs or {} + + return self_influence_fn_from_explainer( + explainer_cls=CaptumSimilarity, + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + init_kwargs=init_kwargs, + ) diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py index 78ee3e71..3fb3805b 100644 --- a/tests/explainers/test_explainers.py +++ b/tests/explainers/test_explainers.py @@ -3,8 +3,10 @@ import pytest import torch -from src.explainers.functional import captum_similarity_explain -from src.explainers.wrappers.captum_influence import CaptumSimilarity +from src.explainers.wrappers.captum_influence import ( + CaptumSimilarity, + captum_similarity_explain, +) from src.utils.functions.similarities import cosine_similarity diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py index d2487d58..931f3109 100644 --- a/tests/explainers/test_self_influence.py +++ b/tests/explainers/test_self_influence.py @@ -6,8 +6,10 @@ import torch from torch.utils.data import TensorDataset -from src.explainers.functional import captum_similarity_self_influence_ranking -from src.explainers.wrappers.captum_influence import CaptumSimilarity +from src.explainers.wrappers.captum_influence import ( + CaptumSimilarity, + captum_similarity_self_influence, +) from src.utils.functions.similarities import dot_product_similarity @@ -29,7 +31,7 @@ def test_self_influence(test_id, init_kwargs, request): y = torch.randint(0, 10, (100,)) rand_dataset = TensorDataset(X, y) - self_influence_rank_functional = captum_similarity_self_influence_ranking( + self_influence_rank_functional = captum_similarity_self_influence( model=model, model_id="0", cache_dir="temp_captum", diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 1ac6137b..e65429a1 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -1,7 +1,7 @@ import pytest import torch -from src.explainers.functional import captum_similarity_explain +from src.explainers.wrappers.captum_influence import captum_similarity_explain from src.metrics.randomization.model_randomization import ( ModelRandomizationMetric, ) From ad2834cde63029780a0b0f34a57821deb532e555 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Thu, 20 Jun 2024 23:15:40 +0200 Subject: [PATCH 23/30] fixing some tests --- Makefile | 2 +- pyproject.toml | 1 + src/metrics/base.py | 2 +- src/metrics/functional.py | 50 ------ src/metrics/localization/identical_class.py | 8 +- .../randomization/model_randomization.py | 39 ++--- src/utils/datasets/utils.py | 2 +- tests/conftest.py | 24 +-- tests/explainers/test_aggregators.py | 12 +- tests/explainers/test_base_explainer.py | 10 +- tests/explainers/test_explainers.py | 77 --------- tests/explainers/test_self_influence.py | 59 ------- tests/explainers/wrappers/__init__.py | 0 .../wrappers/test_captum_influence.py | 151 ++++++++++++++++++ tests/metrics/test_localization_metrics.py | 2 +- tests/metrics/test_randomization_metrics.py | 5 +- tox.ini | 1 + 17 files changed, 206 insertions(+), 239 deletions(-) delete mode 100644 src/metrics/functional.py delete mode 100644 tests/explainers/test_explainers.py delete mode 100644 tests/explainers/test_self_influence.py create mode 100644 tests/explainers/wrappers/__init__.py create mode 100644 tests/explainers/wrappers/test_captum_influence.py diff --git a/Makefile b/Makefile index 56b633f1..09b0fda5 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ SHELL = /bin/bash .PHONY: style style: black . - flake8 . + flake8 . --pytest-parametrize-names-type=csv python -m isort . rm -f .coverage rm -f .coverage.* diff --git a/pyproject.toml b/pyproject.toml index f2c74a04..605f6a98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dev = [ # Install wtih pip install .[dev] or pip install -e '.[dev]' in zsh "coverage>=7.2.3", "flake8>=6.0.0", "pytest<=7.4.4", + "flake8-pytest-style>=1.3.2", "pytest-cov>=4.0.0", "pytest-mock==3.10.0", "pre-commit>=3.2.0", diff --git a/src/metrics/base.py b/src/metrics/base.py index 13d13552..61eb6bd9 100644 --- a/src/metrics/base.py +++ b/src/metrics/base.py @@ -40,7 +40,7 @@ def reset(self, *args, **kwargs): raise NotImplementedError @abstractmethod - def load_state_dict(self, state_dict: dict, *args, **kwargs): + def load_state_dict(self, state_dict: dict): """ Used to load the metric state. """ diff --git a/src/metrics/functional.py b/src/metrics/functional.py deleted file mode 100644 index c6a30995..00000000 --- a/src/metrics/functional.py +++ /dev/null @@ -1,50 +0,0 @@ -""" - -WORK IN PROGRESS!!! -""" - -import warnings -from typing import Optional, Union - -import torch - -from src.utils.cache import ExplanationsCache as EC -from src.utils.explanations import ( - BatchedCachedExplanations, - TensorExplanations, -) - - -def function_example( - model: torch.nn.Module, - train_dataset: torch.utils.data.Dataset, - top_k: int = 1, - explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./", - batch_size: Optional[int] = 8, - device="cpu", - **kwargs, -): - """ - I've copied the existing code from the memory-less metric version here, that can be reused in the future here. - It will not be called "function_example" in the future. There will be many reusable functions, but every metric - will get a functional version here. - - :param model: - :param train_dataset: - :param top_k: - :param explanations: - :param batch_size: - :param device: - :param kwargs: - :return: - """ - if isinstance(explanations, str): - explanations = EC.load(path=explanations, device=device) - if explanations.batch_size != batch_size: - warnings.warn( - "Batch size mismatch between loaded explanations and passed batch size. The inferred batch " - "size will be used instead." - ) - batch_size = explanations[0] - elif isinstance(explanations, torch.Tensor): - explanations = TensorExplanations(explanations, batch_size=batch_size, device=device) diff --git a/src/metrics/localization/identical_class.py b/src/metrics/localization/identical_class.py index 63e06638..fbd14034 100644 --- a/src/metrics/localization/identical_class.py +++ b/src/metrics/localization/identical_class.py @@ -8,11 +8,11 @@ def __init__( self, model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, - device, + device: str, *args, **kwargs, ): - super().__init__(model, train_dataset, device, *args, **kwargs) + super().__init__(model=model, train_dataset=train_dataset, device=device, *args, **kwargs) self.scores = [] def update(self, test_labels: torch.Tensor, explanations: torch.Tensor): @@ -27,8 +27,8 @@ def update(self, test_labels: torch.Tensor, explanations: torch.Tensor): top_one_xpl_indices = explanations.argmax(dim=1) top_one_xpl_targets = torch.stack([self.train_dataset[i][1] for i in top_one_xpl_indices]) - score = (test_labels == top_one_xpl_targets) * 1.0 - self.scores.append(score) + scores = (test_labels == top_one_xpl_targets) * 1.0 + self.scores.append(scores) def compute(self): """ diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index 90d5a669..e8070f8e 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -64,16 +64,17 @@ def __init__( train_dataset=self.train_dataset, ) - self.results = {"rank_correlations": []} + self.results = {"scores": []} + # TODO: create a validation utility function if isinstance(correlation_fn, str) and correlation_fn in correlation_functions: - self.correlation_measure = correlation_functions.get(correlation_fn) + self.corr_measure = correlation_functions.get(correlation_fn) elif callable(correlation_fn): - self.correlation_measure = correlation_fn + self.corr_measure = correlation_fn else: raise ValueError( f"Invalid correlation function: expected one of {list(correlation_functions.keys())} or" - f"a Callable, but got {self.correlation_measure}." + f"a Callable, but got {self.corr_measure}." ) def update( @@ -82,37 +83,39 @@ def update( explanations: torch.Tensor, explanation_targets: torch.Tensor, ): - device = "cuda" if torch.cuda.is_available() else "cpu" rand_explanations = self.explain_fn( - model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=device + model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device ) - corrs = self.correlation_measure(explanations, rand_explanations) - self.results["rank_correlations"].append(corrs) + corrs = self.corr_measure(explanations, rand_explanations) + self.results["scores"].append(corrs) def compute(self): - return torch.cat(self.results["rank_correlations"]).mean() + return torch.cat(self.results["scores"]).mean() def reset(self): - self.results = {"rank_correlations": []} + self.results = {"scores": []} self.generator.manual_seed(self.seed) self.rand_model = self._randomize_model(self.model) def state_dict(self): state_dict = { "results_dict": self.results, - "random_model_state_dict": self.model.state_dict(), - "seed": self.seed, - "generator_state": self.generator.get_state(), - "explain_fn": self.explain_fn, + "rnd_model": self.model.state_dict(), + # Note to Galip: I suggest removing this, because those are explicitly passed + # as init arguments and this is an unexpected side effect if we overwrite them. + # Plus, we only ever use seed to randomize the model once. + # "seed": self.seed, + # "generator_state": self.generator.get_state(), + # "explain_fn": self.explain_fn, } return state_dict def load_state_dict(self, state_dict: dict): self.results = state_dict["results_dict"] - self.seed = state_dict["seed"] - self.explain_fn = state_dict["explain_fn"] - self.rand_model.load_state_dict(state_dict["random_model_state_dict"]) - self.generator.set_state(state_dict["generator_state"]) + self.rand_model.load_state_dict(state_dict["rnd_model"]) + # self.seed = state_dict["seed"] + # self.explain_fn = state_dict["explain_fn"] + # self.generator.set_state(state_dict["generator_state"]) def _randomize_model(self, model: torch.nn.Module): rand_model = copy.deepcopy(model) diff --git a/src/utils/datasets/utils.py b/src/utils/datasets/utils.py index 6f9102fb..4043adc4 100644 --- a/src/utils/datasets/utils.py +++ b/src/utils/datasets/utils.py @@ -32,7 +32,7 @@ def load_datasets(dataset_name, dataset_type, **kwparams): elif dataset_type == "mark": ds = MarkDataset(ds, only_train=only_train) evalds = MarkDataset(evalds, only_train=only_train) - assert ds is not None and evalds is not None + # assert ds is not None and evalds is not None return ds, evalds diff --git a/tests/conftest.py b/tests/conftest.py index 349c79f5..7f95f81c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,24 +14,24 @@ RANDOM_SEED = 42 -@pytest.fixture() +@pytest.fixture def load_dataset(): x = torch.stack([torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)]) y = torch.tensor([0, 1, 0]).long() return torch.utils.data.TensorDataset(x, y) -@pytest.fixture() +@pytest.fixture def load_rand_tensor(): return torch.rand(10, 10).float() -@pytest.fixture() +@pytest.fixture def load_rand_test_predictions(): return torch.randint(0, 10, (10000,)) -@pytest.fixture() +@pytest.fixture def load_mnist_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" model = LeNet() @@ -39,13 +39,13 @@ def load_mnist_model(): return model -@pytest.fixture() +@pytest.fixture def load_init_mnist_model(): """Load a not trained LeNet classification model (architecture at quantus/helpers/models).""" return LeNet() -@pytest.fixture() +@pytest.fixture def load_mnist_dataset(): """Load a batch of MNIST digits: inputs and outputs to use for testing.""" x_batch = ( @@ -58,7 +58,7 @@ def load_mnist_dataset(): return dataset -@pytest.fixture() +@pytest.fixture def load_mnist_dataloader(): """Load a batch of MNIST digits: inputs and outputs to use for testing.""" x_batch = ( @@ -72,26 +72,26 @@ def load_mnist_dataloader(): return dataloader -@pytest.fixture() +@pytest.fixture def load_mnist_test_samples_1(): return torch.load("tests/assets/mnist_test_suite_1/test_dataset.pt") -@pytest.fixture() +@pytest.fixture def load_mnist_test_labels_1(): return torch.load("tests/assets/mnist_test_suite_1/test_labels.pt") -@pytest.fixture() +@pytest.fixture def load_mnist_explanations_1(): return torch.load("tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt") -@pytest.fixture() +@pytest.fixture def torch_cross_entropy_loss_object(): return torch.nn.CrossEntropyLoss() -@pytest.fixture() +@pytest.fixture def torch_sgd_optimizer(): return functools.partial(torch.optim.SGD, lr=0.01, momentum=0.9) diff --git a/tests/explainers/test_aggregators.py b/tests/explainers/test_aggregators.py index 0be36b35..d166ca27 100644 --- a/tests/explainers/test_aggregators.py +++ b/tests/explainers/test_aggregators.py @@ -6,17 +6,15 @@ @pytest.mark.aggregators @pytest.mark.parametrize( - "test_id, dataset, explanations", + "test_id, explanations", [ ( "mnist", - "load_mnist_dataset", "load_mnist_explanations_1", ), ], ) -def test_sum_aggregator(test_id, dataset, explanations, request): - dataset = request.getfixturevalue(dataset) +def test_sum_aggregator(test_id, explanations, request): explanations = request.getfixturevalue(explanations) aggregator = SumAggregator() aggregator.update(explanations) @@ -26,17 +24,15 @@ def test_sum_aggregator(test_id, dataset, explanations, request): @pytest.mark.aggregators @pytest.mark.parametrize( - "test_id, dataset, explanations", + "test_id, explanations", [ ( "mnist", - "load_mnist_dataset", "load_mnist_explanations_1", ), ], ) -def test_abs_aggregator(test_id, dataset, explanations, request): - dataset = request.getfixturevalue(dataset) +def test_abs_aggregator(test_id, explanations, request): explanations = request.getfixturevalue(explanations) aggregator = AbsSumAggregator() aggregator.update(explanations) diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py index 113f0f7e..d4978c06 100644 --- a/tests/explainers/test_base_explainer.py +++ b/tests/explainers/test_base_explainer.py @@ -10,19 +10,21 @@ @pytest.mark.explainers @pytest.mark.parametrize( - "test_id, model, dataset, method_kwargs", + "test_id, model, dataset, explanations, method_kwargs", [ ( "mnist", "load_mnist_model", "load_mnist_dataset", + "load_mnist_explanations_1", {"layers": "relu_4", "similarity_metric": cosine_similarity}, ), ], ) -def test_base_explain_self_influence(test_id, model, dataset, method_kwargs, mocker, request): +def test_base_explain_self_influence(test_id, model, dataset, explanations, method_kwargs, mocker, request): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) + explanations = request.getfixturevalue(explanations) BaseExplainer.__abstractmethods__ = set() explainer = BaseExplainer( @@ -34,9 +36,9 @@ def test_base_explain_self_influence(test_id, model, dataset, method_kwargs, moc **method_kwargs, ) - # Patch the method + # Patch the method, because BaseExplainer has an abstract explain method. def mock_explain(test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None): - return torch.ones((test.shape[0], dataset.__len__())) + return explanations mocker.patch.object(explainer, "explain", wraps=mock_explain) diff --git a/tests/explainers/test_explainers.py b/tests/explainers/test_explainers.py deleted file mode 100644 index 3fb3805b..00000000 --- a/tests/explainers/test_explainers.py +++ /dev/null @@ -1,77 +0,0 @@ -import os - -import pytest -import torch - -from src.explainers.wrappers.captum_influence import ( - CaptumSimilarity, - captum_similarity_explain, -) -from src.utils.functions.similarities import cosine_similarity - - -@pytest.mark.explainers -@pytest.mark.parametrize( - "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations", - [ - ( - "mnist", - "load_mnist_model", - "load_mnist_dataset", - "load_mnist_test_samples_1", - "load_mnist_test_labels_1", - {"layers": "relu_4", "similarity_metric": cosine_similarity}, - "load_mnist_explanations_1", - ), - ], -) -def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request): - model = request.getfixturevalue(model) - dataset = request.getfixturevalue(dataset) - test_tensor = request.getfixturevalue(test_tensor) - test_labels = request.getfixturevalue(test_labels) - explanations_exp = request.getfixturevalue(explanations) - explanations = captum_similarity_explain( - model, - "test_id", - os.path.join("./cache", "test_id"), - test_tensor, - test_labels, - dataset, - device="cpu", - init_kwargs=method_kwargs, - ) - assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" - - -@pytest.mark.explainers -@pytest.mark.parametrize( - "test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs", - [ - ( - "mnist", - "load_mnist_model", - "load_mnist_dataset", - "load_mnist_explanations_1", - "load_mnist_test_samples_1", - "load_mnist_test_labels_1", - {"layers": "relu_4", "similarity_metric": cosine_similarity}, - ), - ], -) -def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request): - model = request.getfixturevalue(model) - dataset = request.getfixturevalue(dataset) - test_tensor = request.getfixturevalue(test_tensor) - test_labels = request.getfixturevalue(test_labels) - explanations_exp = request.getfixturevalue(explanations) - explainer = CaptumSimilarity( - model=model, - model_id="test_id", - cache_dir=os.path.join("./cache", "test_id"), - train_dataset=dataset, - device="cpu", - **method_kwargs, - ) - explanations = explainer.explain(test_tensor) - assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" diff --git a/tests/explainers/test_self_influence.py b/tests/explainers/test_self_influence.py deleted file mode 100644 index 931f3109..00000000 --- a/tests/explainers/test_self_influence.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import shutil -from collections import OrderedDict - -import pytest -import torch -from torch.utils.data import TensorDataset - -from src.explainers.wrappers.captum_influence import ( - CaptumSimilarity, - captum_similarity_self_influence, -) -from src.utils.functions.similarities import dot_product_similarity - - -@pytest.mark.self_influence -@pytest.mark.parametrize( - "test_id, init_kwargs", - [ - ( - "random_data", - {"layers": "identity", "similarity_metric": dot_product_similarity}, - ), - ], -) -def test_self_influence(test_id, init_kwargs, request): - model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())])) - - torch.random.manual_seed(42) - X = torch.randn(100, 200) - y = torch.randint(0, 10, (100,)) - rand_dataset = TensorDataset(X, y) - - self_influence_rank_functional = captum_similarity_self_influence( - model=model, - model_id="0", - cache_dir="temp_captum", - train_dataset=rand_dataset, - init_kwargs=init_kwargs, - device="cpu", - ) - - explainer_obj = CaptumSimilarity( - model=model, - model_id="1", - cache_dir="temp_captum2", - train_dataset=rand_dataset, - device="cpu", - **init_kwargs, - ) - self_influence_rank_stateful = explainer_obj.self_influence() - - if os.path.isdir("temp_captum2"): - shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2")) - if os.path.isdir("temp_captum"): - shutil.rmtree(os.path.join(os.getcwd(), "temp_captum")) - - assert torch.allclose(self_influence_rank_functional, torch.linalg.norm(X, dim=-1).argsort()) - assert torch.allclose(self_influence_rank_functional, self_influence_rank_stateful) diff --git a/tests/explainers/wrappers/__init__.py b/tests/explainers/wrappers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py new file mode 100644 index 00000000..e88c353d --- /dev/null +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -0,0 +1,151 @@ +import os +import shutil +from collections import OrderedDict + +import pytest +import torch +from torch.utils.data import TensorDataset + +from src.explainers.wrappers.captum_influence import ( + CaptumSimilarity, + captum_similarity_explain, + captum_similarity_self_influence, +) +from src.utils.functions.similarities import ( + cosine_similarity, + dot_product_similarity, +) + + +@pytest.mark.self_influence +@pytest.mark.parametrize( + "test_id, init_kwargs", + [ + ( + "random_data", + {"layers": "identity", "similarity_metric": dot_product_similarity}, + ), + ], +) +# TODO: I think a good naming convention is "test_..." or "test_...". +def test_self_influence(test_id, init_kwargs, request): + # TODO: this should be a fixture. + model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())])) + + # TODO: those should be fixtures. We (most of the time) don't generate random data in tests. + torch.random.manual_seed(42) + X = torch.randn(100, 200) + y = torch.randint(0, 10, (100,)) + rand_dataset = TensorDataset(X, y) + + # TODO: One test should test one thing. This is test 1, .... + self_influence_rank_functional = captum_similarity_self_influence( + model=model, + model_id="0", + cache_dir="temp_captum", + train_dataset=rand_dataset, + init_kwargs=init_kwargs, + device="cpu", + ) + + # TODO: ...this is test 2, unless we want to compare that the outputs are the same. + # TODO: If we want to test that the outputs are the same, we should have a separate test for that. + explainer_obj = CaptumSimilarity( + model=model, + model_id="1", + cache_dir="temp_captum2", + train_dataset=rand_dataset, + device="cpu", + **init_kwargs, + ) + + # TODO: self_influence is defined in BaseExplainer - there is a test in test_base_explainer for that. + # TODO: here we then specifically test self_influence for CaptumSimilarity and should make it explicit in the name. + self_influence_rank_stateful = explainer_obj.self_influence() + + # TODO: we check "temp_captum2" but then remove os.path.join(os.getcwd(), "temp_captum2")? + # TODO: is there a reason to fear that the "temp_captum2" folder is not in os.getcwd()? + if os.path.isdir("temp_captum2"): + shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2")) + if os.path.isdir("temp_captum"): + shutil.rmtree(os.path.join(os.getcwd(), "temp_captum")) + + # TODO: what if we pass a non-identity model? Then we don't expect torch.linalg.norm(X, dim=-1).argsort() + # TODO: let's put expectations in the parametrisation of tests. We want to test different scenarios, + # and not some super-specific case. This specific case definitely can be tested as well. + assert torch.allclose(self_influence_rank_functional, torch.linalg.norm(X, dim=-1).argsort()) + # TODO: I think it is best to stick to a single assertion per test (source: Google) + assert torch.allclose(self_influence_rank_functional, self_influence_rank_stateful) + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_explanations_1", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"layers": "relu_4", "similarity_metric": cosine_similarity}, + ), + ], +) +# TODO: I think a good naming convention is "test_..." or "test_...". +# TODO: I would call it test_captum_similarity, because it is a test for the CaptumSimilarity class. +# TODO: We could also make the explainer type (e.g. CaptumSimilarity) a param, then it would be test_explainer or something. +def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + + explainer = CaptumSimilarity( + model=model, + model_id="test_id", + cache_dir=os.path.join("./cache", "test_id"), + train_dataset=dataset, + device="cpu", + **method_kwargs, + ) + # TODO: activations folder clean-up + + explanations = explainer.explain(test_tensor) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"layers": "relu_4", "similarity_metric": cosine_similarity}, + "load_mnist_explanations_1", + ), + ], +) +def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + explanations = captum_similarity_explain( + model, + "test_id", + os.path.join("./cache", "test_id"), + test_tensor, + test_labels, + dataset, + device="cpu", + init_kwargs=method_kwargs, + ) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py index 92618cca..f35c1906 100644 --- a/tests/metrics/test_localization_metrics.py +++ b/tests/metrics/test_localization_metrics.py @@ -5,7 +5,7 @@ @pytest.mark.localization_metrics @pytest.mark.parametrize( - "test_id, model, dataset, test_labels, batch_size, explanations, expected_score", + "test_id,model,dataset,test_labels,batch_size,explanations,expected_score", [ ( "mnist", diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index e65429a1..7716e581 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -48,8 +48,7 @@ def test_randomization_metric_functional( # Can we come up with a special attributor that gets exactly 0 score? metric.update(test_data=test_data, explanations=tda, explanation_targets=test_labels) out = metric.compute() - assert (out.item() >= -1.0) and (out.item() <= 1.0), "Test failed." - assert isinstance(out, torch.Tensor), "Output is not a tensor." + assert (out.item() >= -1.0) & (out.item() <= 1.0), "Test failed." @pytest.mark.randomization_metrics @@ -63,7 +62,7 @@ def test_randomization_metric_functional( ), ], ) -def test_model_randomization(test_id, model, dataset, request): +def test_randomization_metric_model_randomization(test_id, model, dataset, request): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) metric = ModelRandomizationMetric(model=model, train_dataset=dataset, explain_fn=lambda x: x, seed=42, device="cpu") diff --git a/tox.ini b/tox.ini index 0ba7947f..4fdee0da 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,7 @@ [flake8] max-line-length = 127 max-complexity = 10 +pytest-parametrize-names-type = csv ignore = E203 [testenv] From ce8b5383f9225b7f26d295a7b639444585bc27e1 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 10:07:09 +0200 Subject: [PATCH 24/30] resolve merge commit bugs --- src/explainers/utils.py | 2 -- src/explainers/wrappers/__init__.py | 0 .../{captum => wrappers}/captum_influence.py | 4 +--- src/utils/common.py | 12 ++++++++++++ tests/explainers/test_base_explainer.py | 2 +- 5 files changed, 14 insertions(+), 6 deletions(-) create mode 100644 src/explainers/wrappers/__init__.py rename src/explainers/{captum => wrappers}/captum_influence.py (99%) diff --git a/src/explainers/utils.py b/src/explainers/utils.py index f7e121fc..e9a6b3a8 100644 --- a/src/explainers/utils.py +++ b/src/explainers/utils.py @@ -15,7 +15,6 @@ def explain_fn_from_explainer( init_kwargs: Optional[Dict] = None, explain_kwargs: Optional[Dict] = None, ) -> torch.Tensor: - init_kwargs = init_kwargs or {} explain_kwargs = explain_kwargs or {} @@ -41,7 +40,6 @@ def self_influence_fn_from_explainer( init_kwargs: Optional[Dict] = None, explain_kwargs: Optional[Dict] = None, ) -> torch.Tensor: - init_kwargs = init_kwargs or {} explain_kwargs = explain_kwargs or {} diff --git a/src/explainers/wrappers/__init__.py b/src/explainers/wrappers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/explainers/captum/captum_influence.py b/src/explainers/wrappers/captum_influence.py similarity index 99% rename from src/explainers/captum/captum_influence.py rename to src/explainers/wrappers/captum_influence.py index ebcb76b5..cf051d69 100644 --- a/src/explainers/captum/captum_influence.py +++ b/src/explainers/wrappers/captum_influence.py @@ -3,7 +3,7 @@ import torch from captum.influence import SimilarityInfluence -from src.explainers.base_explainer import BaseExplainer +from src.explainers.base import BaseExplainer from src.explainers.utils import ( explain_fn_from_explainer, self_influence_fn_from_explainer, @@ -140,7 +140,6 @@ def captum_similarity_explain( init_kwargs: Optional[Dict] = None, explain_kwargs: Optional[Dict] = None, ) -> torch.Tensor: - init_kwargs = init_kwargs or {} explain_kwargs = explain_kwargs or {} @@ -166,7 +165,6 @@ def captum_similarity_self_influence( init_kwargs: Dict, device: Union[str, torch.device], ) -> torch.Tensor: - init_kwargs = init_kwargs or {} return self_influence_fn_from_explainer( diff --git a/src/utils/common.py b/src/utils/common.py index 1cb7fcf3..b07849ac 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -24,3 +24,15 @@ def make_func(func: Callable, func_kwargs: Optional[Mapping[str, Any]] = None, * _func_kwargs = kwargs return functools.partial(func, **_func_kwargs) + + +def cache_result(method): + cache_attr = f"_{method.__name__}_cache" + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + if cache_attr not in self.__dict__: + self.__dict__[cache_attr] = method(self, *args, **kwargs) + return self.__dict__[cache_attr] + + return wrapper diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py index d4978c06..64374c41 100644 --- a/tests/explainers/test_base_explainer.py +++ b/tests/explainers/test_base_explainer.py @@ -4,7 +4,7 @@ import pytest import torch -from src.explainers.base_explainer import BaseExplainer +from src.explainers.base import BaseExplainer from src.utils.functions.similarities import cosine_similarity From c526bf0b012edc5eff5294b4ad46311908789039 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 10:47:04 +0200 Subject: [PATCH 25/30] self-influence refactoring --- src/explainers/base.py | 19 ++++++++++++++----- tests/conftest.py | 5 +++++ tests/explainers/test_base_explainer.py | 10 +++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/explainers/base.py b/src/explainers/base.py index d0bb5796..c696f9e7 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -30,12 +30,21 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T @cache_result def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch.Tensor: - # Base class implements computing self influences by explaining the train dataset one by one + """ + Base class implements computing self influences by explaining the train dataset one by one + + :param batch_size: + :param kwargs: + :return: + """ + + # Pre-allcate memory for influences, because torch.cat is slow influences = torch.empty((len(self.train_dataset),), device=self.device) ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size) - for i, (x, y) in enumerate(iter(ldr)): - upper_index = i * batch_size + x.shape[0] + + for i, (x, y) in zip(range(0, len(self.train_dataset), batch_size), ldr): explanations = self.explain(test=x.to(self.device), **kwargs) - explanations = explanations[:, i * batch_size : upper_index] - influences[i * batch_size : upper_index] = explanations.diag() + influences[i : i + batch_size] = explanations.diag(diagonal=i) + + # TODO: should we return just the ifnluences and not argsort? return influences.argsort() diff --git a/tests/conftest.py b/tests/conftest.py index 7f95f81c..98f9cb33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -87,6 +87,11 @@ def load_mnist_explanations_1(): return torch.load("tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt") +@pytest.fixture +def load_mnist_dataset_explanations(): + return torch.rand((MINI_BATCH_SIZE, MINI_BATCH_SIZE)) + + @pytest.fixture def torch_cross_entropy_loss_object(): return torch.nn.CrossEntropyLoss() diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py index 64374c41..ac86ec45 100644 --- a/tests/explainers/test_base_explainer.py +++ b/tests/explainers/test_base_explainer.py @@ -10,21 +10,21 @@ @pytest.mark.explainers @pytest.mark.parametrize( - "test_id, model, dataset, explanations, method_kwargs", + "test_id, model, dataset, dataset_xpl, method_kwargs", [ ( "mnist", "load_mnist_model", "load_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_dataset_explanations", {"layers": "relu_4", "similarity_metric": cosine_similarity}, ), ], ) -def test_base_explain_self_influence(test_id, model, dataset, explanations, method_kwargs, mocker, request): +def test_base_explain_self_influence(test_id, model, dataset, dataset_xpl, method_kwargs, mocker, request): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) - explanations = request.getfixturevalue(explanations) + dataset_xpl = request.getfixturevalue(dataset_xpl) BaseExplainer.__abstractmethods__ = set() explainer = BaseExplainer( @@ -38,7 +38,7 @@ def test_base_explain_self_influence(test_id, model, dataset, explanations, meth # Patch the method, because BaseExplainer has an abstract explain method. def mock_explain(test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None): - return explanations + return dataset_xpl[: test.shape[0], : test.shape[0]] mocker.patch.object(explainer, "explain", wraps=mock_explain) From 802a5c21f1e00f33f9472752c5fe66005877edee Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 11:07:07 +0200 Subject: [PATCH 26/30] using tmp_path fixture for tests --- .../wrappers/test_captum_influence.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index e88c353d..c0a1073b 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -1,5 +1,4 @@ import os -import shutil from collections import OrderedDict import pytest @@ -28,7 +27,7 @@ ], ) # TODO: I think a good naming convention is "test_..." or "test_...". -def test_self_influence(test_id, init_kwargs, request): +def test_self_influence(test_id, init_kwargs, tmp_path): # TODO: this should be a fixture. model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())])) @@ -38,11 +37,12 @@ def test_self_influence(test_id, init_kwargs, request): y = torch.randint(0, 10, (100,)) rand_dataset = TensorDataset(X, y) + # Using tmp_path pytest fixtures to create a temporary directory # TODO: One test should test one thing. This is test 1, .... self_influence_rank_functional = captum_similarity_self_influence( model=model, model_id="0", - cache_dir="temp_captum", + cache_dir=str(tmp_path), train_dataset=rand_dataset, init_kwargs=init_kwargs, device="cpu", @@ -53,7 +53,7 @@ def test_self_influence(test_id, init_kwargs, request): explainer_obj = CaptumSimilarity( model=model, model_id="1", - cache_dir="temp_captum2", + cache_dir=str(tmp_path), train_dataset=rand_dataset, device="cpu", **init_kwargs, @@ -63,13 +63,6 @@ def test_self_influence(test_id, init_kwargs, request): # TODO: here we then specifically test self_influence for CaptumSimilarity and should make it explicit in the name. self_influence_rank_stateful = explainer_obj.self_influence() - # TODO: we check "temp_captum2" but then remove os.path.join(os.getcwd(), "temp_captum2")? - # TODO: is there a reason to fear that the "temp_captum2" folder is not in os.getcwd()? - if os.path.isdir("temp_captum2"): - shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2")) - if os.path.isdir("temp_captum"): - shutil.rmtree(os.path.join(os.getcwd(), "temp_captum")) - # TODO: what if we pass a non-identity model? Then we don't expect torch.linalg.norm(X, dim=-1).argsort() # TODO: let's put expectations in the parametrisation of tests. We want to test different scenarios, # and not some super-specific case. This specific case definitely can be tested as well. @@ -96,7 +89,9 @@ def test_self_influence(test_id, init_kwargs, request): # TODO: I think a good naming convention is "test_..." or "test_...". # TODO: I would call it test_captum_similarity, because it is a test for the CaptumSimilarity class. # TODO: We could also make the explainer type (e.g. CaptumSimilarity) a param, then it would be test_explainer or something. -def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request): +def test_explain_stateful( + test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request, tmp_path +): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) test_tensor = request.getfixturevalue(test_tensor) @@ -106,7 +101,7 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te explainer = CaptumSimilarity( model=model, model_id="test_id", - cache_dir=os.path.join("./cache", "test_id"), + cache_dir=str(tmp_path), train_dataset=dataset, device="cpu", **method_kwargs, From c1b4e49a78808866d8217430898c448ffc778ba9 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 11:08:37 +0200 Subject: [PATCH 27/30] using tmp_path fixture for tests cont. --- tests/explainers/wrappers/test_captum_influence.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index c0a1073b..072a6a92 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -1,4 +1,3 @@ -import os from collections import OrderedDict import pytest @@ -106,7 +105,6 @@ def test_explain_stateful( device="cpu", **method_kwargs, ) - # TODO: activations folder clean-up explanations = explainer.explain(test_tensor) assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" @@ -127,7 +125,9 @@ def test_explain_stateful( ), ], ) -def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request): +def test_explain_functional( + test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request, tmp_path +): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) test_tensor = request.getfixturevalue(test_tensor) @@ -136,7 +136,7 @@ def test_explain_functional(test_id, model, dataset, test_tensor, test_labels, m explanations = captum_similarity_explain( model, "test_id", - os.path.join("./cache", "test_id"), + str(tmp_path), test_tensor, test_labels, dataset, From 7d61b79d53542263e9984b27249024f09c636d39 Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 11:26:52 +0200 Subject: [PATCH 28/30] fix a tox issue --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 4fdee0da..355fcf94 100644 --- a/tox.ini +++ b/tox.ini @@ -7,7 +7,7 @@ ignore = E203 [testenv] description = Run the tests with {basepython} deps = - .[tests] + .[dev] commands = pytest -s -v {posargs} From d53320e9cc29d7d007174e5d097f5dc8a0ed9baf Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 12:51:25 +0200 Subject: [PATCH 29/30] make self-influence return attrs + small fixes --- src/explainers/base.py | 3 +-- src/explainers/wrappers/captum_influence.py | 8 +++----- tests/explainers/wrappers/test_captum_influence.py | 4 ++-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/explainers/base.py b/src/explainers/base.py index c696f9e7..35691951 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -46,5 +46,4 @@ def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch explanations = self.explain(test=x.to(self.device), **kwargs) influences[i : i + batch_size] = explanations.diag(diagonal=i) - # TODO: should we return just the ifnluences and not argsort? - return influences.argsort() + return influences diff --git a/src/explainers/wrappers/captum_influence.py b/src/explainers/wrappers/captum_influence.py index cf051d69..e2dcc238 100644 --- a/src/explainers/wrappers/captum_influence.py +++ b/src/explainers/wrappers/captum_influence.py @@ -36,12 +36,10 @@ def __init__( ) self.explainer_cls = explainer_cls self.explain_kwargs = explain_kwargs - self._init_explainer(explainer_cls, **explain_kwargs) + self._init_explainer(**explain_kwargs) - def _init_explainer(self, cls: type, **explain_kwargs: Any): - self.captum_explainer = cls(**explain_kwargs) - if not isinstance(self.captum_explainer, self.explainer_cls): - raise ValueError(f"Expected {self.explainer_cls}, but got {type(self.captum_explainer)}") + def _init_explainer(self, **explain_kwargs: Any): + self.captum_explainer = self.explainer_cls(**explain_kwargs) def _process_targets(self, targets: Optional[Union[List[int], torch.Tensor]]): if targets is not None: diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index 072a6a92..809e0eef 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -45,7 +45,7 @@ def test_self_influence(test_id, init_kwargs, tmp_path): train_dataset=rand_dataset, init_kwargs=init_kwargs, device="cpu", - ) + ).argsort() # TODO: ...this is test 2, unless we want to compare that the outputs are the same. # TODO: If we want to test that the outputs are the same, we should have a separate test for that. @@ -60,7 +60,7 @@ def test_self_influence(test_id, init_kwargs, tmp_path): # TODO: self_influence is defined in BaseExplainer - there is a test in test_base_explainer for that. # TODO: here we then specifically test self_influence for CaptumSimilarity and should make it explicit in the name. - self_influence_rank_stateful = explainer_obj.self_influence() + self_influence_rank_stateful = explainer_obj.self_influence().argsort() # TODO: what if we pass a non-identity model? Then we don't expect torch.linalg.norm(X, dim=-1).argsort() # TODO: let's put expectations in the parametrisation of tests. We want to test different scenarios, From c23bb254a399836a594d52d302f8c738a16b01cf Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 14:59:44 +0200 Subject: [PATCH 30/30] after-merge clean-up --- src/explainers/gradient_product_explainer.py | 0 tests/utils/test_explain_wrapper.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/explainers/gradient_product_explainer.py delete mode 100644 tests/utils/test_explain_wrapper.py diff --git a/src/explainers/gradient_product_explainer.py b/src/explainers/gradient_product_explainer.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/utils/test_explain_wrapper.py b/tests/utils/test_explain_wrapper.py deleted file mode 100644 index e69de29b..00000000