From 0f2b02b2f570d4f8f3868a2ab8fedc22dd635466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 4 Jun 2024 17:31:40 +0200 Subject: [PATCH 01/23] delete whitespaces --- src/metrics/randomization/model_randomization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index a6405306..54d52794 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -79,9 +79,7 @@ def update( corrs = self.correlation_measure(explanations, rand_explanations) self.results["rank_correlations"].append(corrs) - def compute( - self, - ): + def compute(self): return torch.cat(self.results["rank_correlations"]).mean() def reset(self): From 0322da51ff10b2cfe6eb2c7cf73df229affa2707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 4 Jun 2024 17:43:14 +0200 Subject: [PATCH 02/23] first globalization implementation --- src/utils/globalization/__init__.py | 0 src/utils/globalization/base.py | 10 ++++++++++ src/utils/globalization/from_explainer.py | 16 ++++++++++++++++ src/utils/globalization/from_explanations.py | 5 +++++ 4 files changed, 31 insertions(+) create mode 100644 src/utils/globalization/__init__.py create mode 100644 src/utils/globalization/base.py create mode 100644 src/utils/globalization/from_explainer.py create mode 100644 src/utils/globalization/from_explanations.py diff --git a/src/utils/globalization/__init__.py b/src/utils/globalization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/utils/globalization/base.py b/src/utils/globalization/base.py new file mode 100644 index 00000000..a57c82ad --- /dev/null +++ b/src/utils/globalization/base.py @@ -0,0 +1,10 @@ +import torch + +class Globalization(): + def __init__(self, training_dataset, *args,**kwargs): + self.dataset=training_dataset + self.scores=torch.zeros((len(training_dataset))) + raise NotImplementedError + + def get_global_ranking(self): + return self.scores.argmax() \ No newline at end of file diff --git a/src/utils/globalization/from_explainer.py b/src/utils/globalization/from_explainer.py new file mode 100644 index 00000000..509297f7 --- /dev/null +++ b/src/utils/globalization/from_explainer.py @@ -0,0 +1,16 @@ +from src.utils.globalization.base import Globalization + +class GlobalizationFromSingleImageAttributor(Globalization): + def __init__(self, training_dataset, model, attributor_fn, attributor_fn_kwargs): + # why is it called attributor + super().__init__(training_dataset=training_dataset) + self.attributor_fn = attributor_fn + self.model=model + + def compute_self_influences(self): + for i, (x,_) in enumerate(self.training_dataset): + self.scores[i]=self.attributor_fn(datapoint=x) + + def update_self_influences(self, self_influences): + self.scores=self_influences + \ No newline at end of file diff --git a/src/utils/globalization/from_explanations.py b/src/utils/globalization/from_explanations.py new file mode 100644 index 00000000..ef35e8f9 --- /dev/null +++ b/src/utils/globalization/from_explanations.py @@ -0,0 +1,5 @@ +from src.utils.globalization.base import Globalization + +class GlobalizationFromExplanations(Globalization): + def update(self, explanations): + self.scores += explanations.abs().sum(dim=0) \ No newline at end of file From b75edd9b1976d6f00f92c615227ee046221a71d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 4 Jun 2024 17:46:07 +0200 Subject: [PATCH 03/23] minor fix to top_k_overlap --- src/metrics/unnamed/top_k_overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metrics/unnamed/top_k_overlap.py b/src/metrics/unnamed/top_k_overlap.py index 66a3b8b8..70cebea0 100644 --- a/src/metrics/unnamed/top_k_overlap.py +++ b/src/metrics/unnamed/top_k_overlap.py @@ -29,7 +29,7 @@ def compute(self, *args, **kwargs): return len(torch.unique(self.all_top_k_examples)) def reset(self, *args, **kwargs): - self.all_top_k_examples = [] + self.all_top_k_examples = torch.empty(0, top_k) def load_state_dict(self, state_dict: dict, *args, **kwargs): self.all_top_k_examples = state_dict["all_top_k_examples"] From 37a78e73655f6d0283bce3d4f15693562a184a0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 4 Jun 2024 18:01:42 +0200 Subject: [PATCH 04/23] add make_func --- src/utils/globalization/from_explainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils/globalization/from_explainer.py b/src/utils/globalization/from_explainer.py index 509297f7..86a7f638 100644 --- a/src/utils/globalization/from_explainer.py +++ b/src/utils/globalization/from_explainer.py @@ -1,10 +1,12 @@ from src.utils.globalization.base import Globalization +from src.utils.common import make_func class GlobalizationFromSingleImageAttributor(Globalization): def __init__(self, training_dataset, model, attributor_fn, attributor_fn_kwargs): # why is it called attributor super().__init__(training_dataset=training_dataset) - self.attributor_fn = attributor_fn + self.attributor_fn = make_func(func=attributor_fn, func_kwargs=explain_fn_kwargs, model=self.model) + self.model=model def compute_self_influences(self): From 501d9c9905df91e949bfde4a9229056ec03b4f07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 4 Jun 2024 18:03:54 +0200 Subject: [PATCH 05/23] small fix on top_k_overlap --- src/metrics/unnamed/top_k_overlap.py | 2 +- src/utils/globalization/from_explainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/metrics/unnamed/top_k_overlap.py b/src/metrics/unnamed/top_k_overlap.py index 70cebea0..e5f70235 100644 --- a/src/metrics/unnamed/top_k_overlap.py +++ b/src/metrics/unnamed/top_k_overlap.py @@ -29,7 +29,7 @@ def compute(self, *args, **kwargs): return len(torch.unique(self.all_top_k_examples)) def reset(self, *args, **kwargs): - self.all_top_k_examples = torch.empty(0, top_k) + self.all_top_k_examples = torch.empty(0, self.top_k) def load_state_dict(self, state_dict: dict, *args, **kwargs): self.all_top_k_examples = state_dict["all_top_k_examples"] diff --git a/src/utils/globalization/from_explainer.py b/src/utils/globalization/from_explainer.py index 86a7f638..e111a2b4 100644 --- a/src/utils/globalization/from_explainer.py +++ b/src/utils/globalization/from_explainer.py @@ -5,7 +5,7 @@ class GlobalizationFromSingleImageAttributor(Globalization): def __init__(self, training_dataset, model, attributor_fn, attributor_fn_kwargs): # why is it called attributor super().__init__(training_dataset=training_dataset) - self.attributor_fn = make_func(func=attributor_fn, func_kwargs=explain_fn_kwargs, model=self.model) + self.attributor_fn = make_func(func=attributor_fn, func_kwargs=attributor_fn_kwargs, model=self.model) self.model=model From df0691e309cf9e90eea485371332df2e599a0ff8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 4 Jun 2024 18:11:53 +0200 Subject: [PATCH 06/23] code style fixes --- src/metrics/localization/identical_class.py | 6 +----- src/utils/explain_wrapper.py | 2 +- src/utils/globalization/base.py | 11 ++++++----- src/utils/globalization/from_explainer.py | 15 +++++++-------- src/utils/globalization/from_explanations.py | 3 ++- 5 files changed, 17 insertions(+), 20 deletions(-) diff --git a/src/metrics/localization/identical_class.py b/src/metrics/localization/identical_class.py index 30c80413..63e06638 100644 --- a/src/metrics/localization/identical_class.py +++ b/src/metrics/localization/identical_class.py @@ -15,11 +15,7 @@ def __init__( super().__init__(model, train_dataset, device, *args, **kwargs) self.scores = [] - def update( - self, - test_labels: torch.Tensor, - explanations: torch.Tensor - ): + def update(self, test_labels: torch.Tensor, explanations: torch.Tensor): """ Used to implement metric-specific logic. """ diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index 3f2191fc..03e4775e 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -16,7 +16,7 @@ def __call__( test_tensor: torch.Tensor, method: str, ) -> torch.Tensor: - ... + pass def explain( diff --git a/src/utils/globalization/base.py b/src/utils/globalization/base.py index a57c82ad..4a319f05 100644 --- a/src/utils/globalization/base.py +++ b/src/utils/globalization/base.py @@ -1,10 +1,11 @@ import torch -class Globalization(): - def __init__(self, training_dataset, *args,**kwargs): - self.dataset=training_dataset - self.scores=torch.zeros((len(training_dataset))) + +class Globalization: + def __init__(self, training_dataset, *args, **kwargs): + self.dataset = training_dataset + self.scores = torch.zeros((len(training_dataset))) raise NotImplementedError def get_global_ranking(self): - return self.scores.argmax() \ No newline at end of file + return self.scores.argmax() diff --git a/src/utils/globalization/from_explainer.py b/src/utils/globalization/from_explainer.py index e111a2b4..7889329e 100644 --- a/src/utils/globalization/from_explainer.py +++ b/src/utils/globalization/from_explainer.py @@ -1,18 +1,17 @@ -from src.utils.globalization.base import Globalization from src.utils.common import make_func +from src.utils.globalization.base import Globalization + class GlobalizationFromSingleImageAttributor(Globalization): def __init__(self, training_dataset, model, attributor_fn, attributor_fn_kwargs): # why is it called attributor super().__init__(training_dataset=training_dataset) self.attributor_fn = make_func(func=attributor_fn, func_kwargs=attributor_fn_kwargs, model=self.model) + self.model = model - self.model=model - def compute_self_influences(self): - for i, (x,_) in enumerate(self.training_dataset): - self.scores[i]=self.attributor_fn(datapoint=x) - + for i, (x, _) in enumerate(self.training_dataset): + self.scores[i] = self.attributor_fn(datapoint=x) + def update_self_influences(self, self_influences): - self.scores=self_influences - \ No newline at end of file + self.scores = self_influences diff --git a/src/utils/globalization/from_explanations.py b/src/utils/globalization/from_explanations.py index ef35e8f9..9a0a0368 100644 --- a/src/utils/globalization/from_explanations.py +++ b/src/utils/globalization/from_explanations.py @@ -1,5 +1,6 @@ from src.utils.globalization.base import Globalization + class GlobalizationFromExplanations(Globalization): def update(self, explanations): - self.scores += explanations.abs().sum(dim=0) \ No newline at end of file + self.scores += explanations.abs().sum(dim=0) From 7e34a3d9b90254add18d910b6e7eb5a667ffa3e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 4 Jun 2024 18:19:05 +0200 Subject: [PATCH 07/23] type hints --- src/utils/globalization/base.py | 2 +- src/utils/globalization/from_explainer.py | 12 +++++++++++- src/utils/globalization/from_explanations.py | 4 +++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/utils/globalization/base.py b/src/utils/globalization/base.py index 4a319f05..947b5266 100644 --- a/src/utils/globalization/base.py +++ b/src/utils/globalization/base.py @@ -2,7 +2,7 @@ class Globalization: - def __init__(self, training_dataset, *args, **kwargs): + def __init__(self, training_dataset: torch.utils.data.Dataset, *args, **kwargs): self.dataset = training_dataset self.scores = torch.zeros((len(training_dataset))) raise NotImplementedError diff --git a/src/utils/globalization/from_explainer.py b/src/utils/globalization/from_explainer.py index 7889329e..13c03c70 100644 --- a/src/utils/globalization/from_explainer.py +++ b/src/utils/globalization/from_explainer.py @@ -1,9 +1,19 @@ +from typing import Callable, Optional + +import torch + from src.utils.common import make_func from src.utils.globalization.base import Globalization class GlobalizationFromSingleImageAttributor(Globalization): - def __init__(self, training_dataset, model, attributor_fn, attributor_fn_kwargs): + def __init__( + self, + training_dataset: torch.utils.data.Dataset, + model: torch.nn.Module, + attributor_fn: Callable, + attributor_fn_kwargs: Optional[dict] = None, + ): # why is it called attributor super().__init__(training_dataset=training_dataset) self.attributor_fn = make_func(func=attributor_fn, func_kwargs=attributor_fn_kwargs, model=self.model) diff --git a/src/utils/globalization/from_explanations.py b/src/utils/globalization/from_explanations.py index 9a0a0368..8d027655 100644 --- a/src/utils/globalization/from_explanations.py +++ b/src/utils/globalization/from_explanations.py @@ -1,6 +1,8 @@ +import torch + from src.utils.globalization.base import Globalization class GlobalizationFromExplanations(Globalization): - def update(self, explanations): + def update(self, explanations: torch.Tensor): self.scores += explanations.abs().sum(dim=0) From 85f3b4ba525e0ff57bbc9ee08c13e0227d37b03e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Sun, 9 Jun 2024 19:19:43 +0200 Subject: [PATCH 08/23] ExplanationsAggregator and get_self_influence_ranking --- src/utils/aggregators.py | 24 +++++++++++++++++ src/utils/common.py | 21 ++++++++++++++- src/utils/explain_wrapper.py | 12 +++++++++ src/utils/globalization/__init__.py | 0 src/utils/globalization/base.py | 11 -------- src/utils/globalization/from_explainer.py | 27 -------------------- src/utils/globalization/from_explanations.py | 8 ------ 7 files changed, 56 insertions(+), 47 deletions(-) create mode 100644 src/utils/aggregators.py delete mode 100644 src/utils/globalization/__init__.py delete mode 100644 src/utils/globalization/base.py delete mode 100644 src/utils/globalization/from_explainer.py delete mode 100644 src/utils/globalization/from_explanations.py diff --git a/src/utils/aggregators.py b/src/utils/aggregators.py new file mode 100644 index 00000000..df07e825 --- /dev/null +++ b/src/utils/aggregators.py @@ -0,0 +1,24 @@ +from abc import ABC + +import torch + + +class ExplanationsAggregator(ABC): + def __init__(self, training_size: int, *args, **kwargs): + self.scores = torch.zeros(training_size) + + def update(self, explanations: torch.Tensor): + raise NotImplementedError + + def get_global_ranking(self) -> torch.Tensor: + return self.scores.argsort() + + +class SumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: + self.scores += explanations.sum(dim=0) + + +class AbsSumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: + self.scores += explanations.abs().sum(dim=0) diff --git a/src/utils/common.py b/src/utils/common.py index 45e34a05..b5686ddf 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -1,8 +1,12 @@ import functools from functools import reduce -from typing import Any, Callable, Mapping +from typing import Any, Callable, Mapping, Optional import torch +import torch.utils +import torch.utils.data + +from utils.explain_wrapper import SelfInfluenceFunction def _get_module_from_name(model: torch.nn.Module, layer_name: str) -> Any: @@ -22,3 +26,18 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None, **kwargs) - func_kwargs = kwargs return functools.partial(func, **func_kwargs) + + +def get_self_influence_ranking( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + training_data: torch.utils.data.Dataset, + self_influence_fn: SelfInfluenceFunction, + self_influence_fn_kwargs: Optional[dict] = None, +) -> torch.Tensor: + size = len(training_data) + self_inf = torch.zeros((size,)) + for i, (x, y) in enumerate(training_data): + self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i) + return self_inf.argsort() diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index 03e4775e..367b6f5c 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -19,6 +19,18 @@ def __call__( pass +class SelfInfluenceFunction(Protocol): + def __call__( + self, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + id: int, + ) -> torch.Tensor: + pass + + def explain( model: torch.nn.Module, model_id: str, diff --git a/src/utils/globalization/__init__.py b/src/utils/globalization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/globalization/base.py b/src/utils/globalization/base.py deleted file mode 100644 index 947b5266..00000000 --- a/src/utils/globalization/base.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - - -class Globalization: - def __init__(self, training_dataset: torch.utils.data.Dataset, *args, **kwargs): - self.dataset = training_dataset - self.scores = torch.zeros((len(training_dataset))) - raise NotImplementedError - - def get_global_ranking(self): - return self.scores.argmax() diff --git a/src/utils/globalization/from_explainer.py b/src/utils/globalization/from_explainer.py deleted file mode 100644 index 13c03c70..00000000 --- a/src/utils/globalization/from_explainer.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Callable, Optional - -import torch - -from src.utils.common import make_func -from src.utils.globalization.base import Globalization - - -class GlobalizationFromSingleImageAttributor(Globalization): - def __init__( - self, - training_dataset: torch.utils.data.Dataset, - model: torch.nn.Module, - attributor_fn: Callable, - attributor_fn_kwargs: Optional[dict] = None, - ): - # why is it called attributor - super().__init__(training_dataset=training_dataset) - self.attributor_fn = make_func(func=attributor_fn, func_kwargs=attributor_fn_kwargs, model=self.model) - self.model = model - - def compute_self_influences(self): - for i, (x, _) in enumerate(self.training_dataset): - self.scores[i] = self.attributor_fn(datapoint=x) - - def update_self_influences(self, self_influences): - self.scores = self_influences diff --git a/src/utils/globalization/from_explanations.py b/src/utils/globalization/from_explanations.py deleted file mode 100644 index 8d027655..00000000 --- a/src/utils/globalization/from_explanations.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch - -from src.utils.globalization.base import Globalization - - -class GlobalizationFromExplanations(Globalization): - def update(self, explanations: torch.Tensor): - self.scores += explanations.abs().sum(dim=0) From 43c950efb65ea724f6f9452390f540b1de5c1040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Sun, 9 Jun 2024 19:36:43 +0200 Subject: [PATCH 09/23] use self_influence_fn_kwargs --- src/utils/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/common.py b/src/utils/common.py index b5686ddf..ff833f75 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -39,5 +39,5 @@ def get_self_influence_ranking( size = len(training_data) self_inf = torch.zeros((size,)) for i, (x, y) in enumerate(training_data): - self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i) + self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i, **self_influence_fn_kwargs) return self_inf.argsort() From 80023ba98e81be0ca1cd648555651e9163cba005 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Wed, 12 Jun 2024 17:44:40 +0200 Subject: [PATCH 10/23] Changes suggested by Dilya --- .../aggregators}/aggregators.py | 23 +++++++++++-- src/explainers/aggregators/self_influence.py | 33 +++++++++++++++++++ src/utils/common.py | 19 +---------- src/utils/explain_wrapper.py | 12 ------- 4 files changed, 55 insertions(+), 32 deletions(-) rename src/{utils => explainers/aggregators}/aggregators.py (51%) create mode 100644 src/explainers/aggregators/self_influence.py diff --git a/src/utils/aggregators.py b/src/explainers/aggregators/aggregators.py similarity index 51% rename from src/utils/aggregators.py rename to src/explainers/aggregators/aggregators.py index df07e825..0d36b6b4 100644 --- a/src/utils/aggregators.py +++ b/src/explainers/aggregators/aggregators.py @@ -1,4 +1,4 @@ -from abc import ABC +from abc import ABC, abstractmethod import torch @@ -7,10 +7,29 @@ class ExplanationsAggregator(ABC): def __init__(self, training_size: int, *args, **kwargs): self.scores = torch.zeros(training_size) + @abstractmethod def update(self, explanations: torch.Tensor): raise NotImplementedError - def get_global_ranking(self) -> torch.Tensor: + def reset(self, *args, **kwargs): + """ + Used to reset the aggregator state. + """ + self.scores = torch.zeros_like(self.scores) + + def load_state_dict(self, state_dict: dict, *args, **kwargs): + """ + Used to load the aggregator state. + """ + self.scores = state_dict["scores"] + + def state_dict(self, *args, **kwargs): + """ + Used to return the metric state. + """ + return {"scores": self.scores} + + def compute(self) -> torch.Tensor: return self.scores.argsort() diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py new file mode 100644 index 00000000..32b0c6e8 --- /dev/null +++ b/src/explainers/aggregators/self_influence.py @@ -0,0 +1,33 @@ +from typing import Optional, Protocol + +import torch + +from utils.common import make_func + + +class SelfInfluenceFunction(Protocol): + def __call__( + self, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + id: int, + ) -> torch.Tensor: + pass + + +def get_self_influence_ranking( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + training_data: torch.utils.data.Dataset, + self_influence_fn: SelfInfluenceFunction, + self_influence_fn_kwargs: Optional[dict] = None, +) -> torch.Tensor: + size = len(training_data) + self_inf = torch.zeros((size,)) + self_influence_fn = make_func + for i, (x, y) in enumerate(training_data): + self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i, **self_influence_fn_kwargs) + return self_inf.argsort() diff --git a/src/utils/common.py b/src/utils/common.py index ff833f75..777c8e7a 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -1,13 +1,11 @@ import functools from functools import reduce -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Mapping import torch import torch.utils import torch.utils.data -from utils.explain_wrapper import SelfInfluenceFunction - def _get_module_from_name(model: torch.nn.Module, layer_name: str) -> Any: return reduce(getattr, layer_name.split("."), model) @@ -26,18 +24,3 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None, **kwargs) - func_kwargs = kwargs return functools.partial(func, **func_kwargs) - - -def get_self_influence_ranking( - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - training_data: torch.utils.data.Dataset, - self_influence_fn: SelfInfluenceFunction, - self_influence_fn_kwargs: Optional[dict] = None, -) -> torch.Tensor: - size = len(training_data) - self_inf = torch.zeros((size,)) - for i, (x, y) in enumerate(training_data): - self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i, **self_influence_fn_kwargs) - return self_inf.argsort() diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index 367b6f5c..03e4775e 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -19,18 +19,6 @@ def __call__( pass -class SelfInfluenceFunction(Protocol): - def __call__( - self, - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - train_dataset: torch.utils.data.Dataset, - id: int, - ) -> torch.Tensor: - pass - - def explain( model: torch.nn.Module, model_id: str, From 23abeec7f2f1539ab142d80b010d495230e51191 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Wed, 12 Jun 2024 17:53:14 +0200 Subject: [PATCH 11/23] give all explainer parameters to make_func in randomization metric --- src/metrics/localization/identical_class.py | 6 +----- .../randomization/model_randomization.py | 17 +++++++++-------- src/utils/explain_wrapper.py | 2 +- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/metrics/localization/identical_class.py b/src/metrics/localization/identical_class.py index 30c80413..63e06638 100644 --- a/src/metrics/localization/identical_class.py +++ b/src/metrics/localization/identical_class.py @@ -15,11 +15,7 @@ def __init__( super().__init__(model, train_dataset, device, *args, **kwargs) self.scores = [] - def update( - self, - test_labels: torch.Tensor, - explanations: torch.Tensor - ): + def update(self, test_labels: torch.Tensor, explanations: torch.Tensor): """ Used to implement metric-specific logic. """ diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index a6405306..b551c07d 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -51,7 +51,14 @@ def __init__( # or do we want genuinely random models for each call of the metric (keeping seed in the constructor) self.generator = torch.Generator(device=device) self.rand_model = self._randomize_model(model) - self.explain_fn = make_func(func=explain_fn, func_kwargs=explain_fn_kwargs, model=self.rand_model) + self.explain_fn = make_func( + func=explain_fn, + func_kwargs=explain_fn_kwargs, + model=self.rand_model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + ) self.results = {"rank_correlations": []} if isinstance(correlation_fn, str) and correlation_fn in correlation_functions: @@ -69,13 +76,7 @@ def update( test_data: torch.Tensor, explanations: torch.Tensor, ): - rand_explanations = self.explain_fn( - model=self.rand_model, - model_id=self.model_id, - cache_dir=self.cache_dir, - train_dataset=self.train_dataset, - test_tensor=test_data, - ) + rand_explanations = self.explain_fn(test_tensor=test_data) corrs = self.correlation_measure(explanations, rand_explanations) self.results["rank_correlations"].append(corrs) diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index 3f2191fc..03e4775e 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -16,7 +16,7 @@ def __call__( test_tensor: torch.Tensor, method: str, ) -> torch.Tensor: - ... + pass def explain( From 6e6365923c4a788a8279cddadc13b2c6f259c473 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Wed, 12 Jun 2024 18:15:50 +0200 Subject: [PATCH 12/23] add randomized model state dict to the metric state dict --- src/metrics/randomization/model_randomization.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index b551c07d..ad4fb6c0 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -87,12 +87,18 @@ def compute( def reset(self): self.results = {"rank_correlations": []} + self.rand_model = self._randomize_model(self.model) def state_dict(self): - return self.results + state_dict = {} + state_dict.update(self.results) + state_dict["random_model_state_dict"] = self.model.state_dict() + return state_dict def load_state_dict(self, state_dict: dict): - self.results = state_dict + self.rand_model.load_state_dict(state_dict["random_model_state_dict"]) + state_dict.pop("random_model_state_dict", None) + self.results.update(state_dict) def _randomize_model(self, model: torch.nn.Module): rand_model = copy.deepcopy(model) From c5fb56b6f6ae865aff42f1136dbb160863b469a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Wed, 12 Jun 2024 19:23:44 +0200 Subject: [PATCH 13/23] get rid of functoolss.partial because it is very hard to use just to make 1 line smaller in each metric --- .../randomization/model_randomization.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index ad4fb6c0..dca5bb3b 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -4,7 +4,7 @@ import torch from metrics.base import Metric -from utils.common import _get_parent_module_from_name, make_func +from utils.common import _get_parent_module_from_name from utils.explain_wrapper import ExplainFunc from utils.functions.correlations import ( CorrelationFnLiterals, @@ -51,14 +51,8 @@ def __init__( # or do we want genuinely random models for each call of the metric (keeping seed in the constructor) self.generator = torch.Generator(device=device) self.rand_model = self._randomize_model(model) - self.explain_fn = make_func( - func=explain_fn, - func_kwargs=explain_fn_kwargs, - model=self.rand_model, - model_id=model_id, - cache_dir=cache_dir, - train_dataset=train_dataset, - ) + self.explain_fn = explain_fn + self.results = {"rank_correlations": []} if isinstance(correlation_fn, str) and correlation_fn in correlation_functions: @@ -76,7 +70,14 @@ def update( test_data: torch.Tensor, explanations: torch.Tensor, ): - rand_explanations = self.explain_fn(test_tensor=test_data) + rand_explanations = self.explain_fn( + model=self.rand_model, + model_id=self.model_id, + cache_dir=self.cache_dir, + train_dataset=self.train_dataset, + test_tensor=test_data, + **self.explain_fn_kwargs, + ) corrs = self.correlation_measure(explanations, rand_explanations) self.results["rank_correlations"].append(corrs) From ecc621a80c17785c68c6f4af0c0c4d43c053b805 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Wed, 12 Jun 2024 22:59:25 +0200 Subject: [PATCH 14/23] fix bug in make_func, add make_func functionality to randomization metric --- .../randomization/model_randomization.py | 37 +++++++++++-------- src/usage.py | 21 ----------- src/utils/common.py | 4 +- 3 files changed, 24 insertions(+), 38 deletions(-) delete mode 100644 src/usage.py diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py index dca5bb3b..4970e326 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/src/metrics/randomization/model_randomization.py @@ -4,7 +4,7 @@ import torch from metrics.base import Metric -from utils.common import _get_parent_module_from_name +from utils.common import _get_parent_module_from_name, make_func from utils.explain_wrapper import ExplainFunc from utils.functions.correlations import ( CorrelationFnLiterals, @@ -50,8 +50,15 @@ def __init__( # do we want the exact same random model to be attributed (keeping seed in the __call__ call) # or do we want genuinely random models for each call of the metric (keeping seed in the constructor) self.generator = torch.Generator(device=device) + self.generator.manual_seed(self.seed) self.rand_model = self._randomize_model(model) - self.explain_fn = explain_fn + self.explain_fn = make_func( + func=explain_fn, + func_kwargs=explain_fn_kwargs, + model_id=self.model_id, + cache_dir=self.cache_dir, + train_dataset=self.train_dataset, + ) self.results = {"rank_correlations": []} @@ -70,14 +77,7 @@ def update( test_data: torch.Tensor, explanations: torch.Tensor, ): - rand_explanations = self.explain_fn( - model=self.rand_model, - model_id=self.model_id, - cache_dir=self.cache_dir, - train_dataset=self.train_dataset, - test_tensor=test_data, - **self.explain_fn_kwargs, - ) + rand_explanations = self.explain_fn(model=self.rand_model, test_tensor=test_data) corrs = self.correlation_measure(explanations, rand_explanations) self.results["rank_correlations"].append(corrs) @@ -88,18 +88,25 @@ def compute( def reset(self): self.results = {"rank_correlations": []} + self.generator.manual_seed(self.seed) self.rand_model = self._randomize_model(self.model) def state_dict(self): - state_dict = {} - state_dict.update(self.results) - state_dict["random_model_state_dict"] = self.model.state_dict() + state_dict = { + "results_dict": self.results, + "random_model_state_dict": self.model.state_dict(), + "seed": self.seed, + "generator_state": self.generator.get_state(), + "explain_fn": self.explain_fn, + } return state_dict def load_state_dict(self, state_dict: dict): + self.results = state_dict["results_dict"] + self.seed = state_dict["seed"] + self.explain_fn = state_dict["explain_fn"] self.rand_model.load_state_dict(state_dict["random_model_state_dict"]) - state_dict.pop("random_model_state_dict", None) - self.results.update(state_dict) + self.generator.set_state(state_dict["generator_state"]) def _randomize_model(self, model: torch.nn.Module): rand_model = copy.deepcopy(model) diff --git a/src/usage.py b/src/usage.py deleted file mode 100644 index 615cc356..00000000 --- a/src/usage.py +++ /dev/null @@ -1,21 +0,0 @@ -from torch.utils.data import DataLoader -from torchvision.datasets import MNIST -from torchvision.models import resnet18 - -from src.explainers.base import Explainer -from src.metrics.base import Metric - -train_ds = MNIST(root="~/Documents/Code/Datasets", train=True) -test_ds = MNIST(root="~/Documents/Code/Datasets", train=False) -test_ld = DataLoader(test_ds, batch_size=32) -model = resnet18() -# Possibly get special kinds of datasets here -metric = Metric(train_ds, test_ds) -# Possibly train model on the special kind of dataset with something like metric.train_model() -explainer = Explainer(model, train_ds, "cuda") -explainer.train() -for x, y in iter(test_ld): - preds = model(x).argmax(dim=-1) - xpl = explainer.explain(x, preds) - metric(xpl) -metric.get_result() diff --git a/src/utils/common.py b/src/utils/common.py index 45e34a05..d63bcc4b 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -19,6 +19,6 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None, **kwargs) - _func_kwargs = kwargs.copy() _func_kwargs.update(func_kwargs) else: - func_kwargs = kwargs + _func_kwargs = kwargs - return functools.partial(func, **func_kwargs) + return functools.partial(func, **_func_kwargs) From 91409e225027282bce385ba3eca4a751af2e4a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 13 Jun 2024 17:32:44 +0200 Subject: [PATCH 15/23] changes needed for using explain_fn for self influence --- src/explainers/aggregators/self_influence.py | 24 +++++++------------- src/metrics/unnamed/top_k_overlap.py | 2 +- src/utils/explain_wrapper.py | 16 +++++++++---- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index 32b0c6e8..228b3a61 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -1,20 +1,10 @@ -from typing import Optional, Protocol +from typing import Optional +from warnings import warn import torch from utils.common import make_func - - -class SelfInfluenceFunction(Protocol): - def __call__( - self, - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - train_dataset: torch.utils.data.Dataset, - id: int, - ) -> torch.Tensor: - pass +from utils.explain_wrapper import ExplainFunc def get_self_influence_ranking( @@ -22,12 +12,14 @@ def get_self_influence_ranking( model_id: str, cache_dir: Optional[str], training_data: torch.utils.data.Dataset, - self_influence_fn: SelfInfluenceFunction, - self_influence_fn_kwargs: Optional[dict] = None, + explain_fn: ExplainFunc, + explain_fn_kwargs: Optional[dict] = None, ) -> torch.Tensor: + if "train_ids" not in explain_fn_kwargs: + warn("train_id is supplied to compute self-influences. Supplied indices will be ignored.") size = len(training_data) self_inf = torch.zeros((size,)) self_influence_fn = make_func for i, (x, y) in enumerate(training_data): - self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i, **self_influence_fn_kwargs) + self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i, **explain_fn_kwargs) return self_inf.argsort() diff --git a/src/metrics/unnamed/top_k_overlap.py b/src/metrics/unnamed/top_k_overlap.py index e5f70235..cb872726 100644 --- a/src/metrics/unnamed/top_k_overlap.py +++ b/src/metrics/unnamed/top_k_overlap.py @@ -13,7 +13,7 @@ def __init__( *args, **kwargs, ): - super().__init__(model, train_dataset, *args, **kwargs) + # super().__init__(model, train_dataset, *args, **kwargs) self.top_k = top_k self.all_top_k_examples = torch.empty(0, top_k) diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index 03e4775e..ebe49ec4 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -1,8 +1,9 @@ -from typing import Optional, Protocol +from typing import List, Optional, Protocol, Union import torch from captum.influence import SimilarityInfluence +from src.utils.datasets.indexed_subset import IndexedSubset from src.utils.functions.similarities import cosine_similarity @@ -12,9 +13,10 @@ def __call__( model: torch.nn.Module, model_id: str, cache_dir: Optional[str], - train_dataset: torch.utils.data.Dataset, - test_tensor: torch.Tensor, method: str, + test_tensor: torch.Tensor, + train_dataset: torch.utils.data.Dataset, + train_ids: Optional[Union[List[int], torch.Tensor]] = None, ) -> torch.Tensor: pass @@ -23,9 +25,10 @@ def explain( model: torch.nn.Module, model_id: str, cache_dir: str, - train_dataset: torch.utils.data.Dataset, - test_tensor: torch.Tensor, method: str, + test_tensor: torch.Tensor, + train_dataset: torch.utils.data.Dataset, + train_ids: Optional[Union[List[int], torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: """ @@ -40,6 +43,9 @@ def explain( :param kwargs: :return: """ + if train_ids is not None: + train_dataset = IndexedSubset(dataset=train_dataset, indices=train_ids) + if method == "SimilarityInfluence": layer = kwargs.get("layer", "features") sim_metric = kwargs.get("similarity_metric", cosine_similarity) From 3256cbb6f932688f91fe72689369188eb1f68269 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 13 Jun 2024 18:53:47 +0200 Subject: [PATCH 16/23] minor fix --- src/explainers/aggregators/self_influence.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index 228b3a61..269bdb6c 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -3,7 +3,6 @@ import torch -from utils.common import make_func from utils.explain_wrapper import ExplainFunc @@ -19,7 +18,6 @@ def get_self_influence_ranking( warn("train_id is supplied to compute self-influences. Supplied indices will be ignored.") size = len(training_data) self_inf = torch.zeros((size,)) - self_influence_fn = make_func for i, (x, y) in enumerate(training_data): - self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i, **explain_fn_kwargs) + self_inf[i] = explain_fn(model, model_id, cache_dir, training_data, i, **explain_fn_kwargs) return self_inf.argsort() From 2b2ac6e56a5e63b3f8fa5888b19073b02262ae0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 13 Jun 2024 18:55:05 +0200 Subject: [PATCH 17/23] minor fix --- src/explainers/aggregators/self_influence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index 269bdb6c..45146dc5 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -9,7 +9,7 @@ def get_self_influence_ranking( model: torch.nn.Module, model_id: str, - cache_dir: Optional[str], + cache_dir: str, training_data: torch.utils.data.Dataset, explain_fn: ExplainFunc, explain_fn_kwargs: Optional[dict] = None, From e5682f0a55b7f60ca47cca12b310cd0c083d56ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 13 Jun 2024 19:08:35 +0200 Subject: [PATCH 18/23] add test_target to explainer_fn parameters, because most methods need an output neuron to explain --- src/explainers/aggregators/self_influence.py | 12 ++++++++++-- src/utils/explain_wrapper.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index 45146dc5..b0f83a14 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -19,5 +19,13 @@ def get_self_influence_ranking( size = len(training_data) self_inf = torch.zeros((size,)) for i, (x, y) in enumerate(training_data): - self_inf[i] = explain_fn(model, model_id, cache_dir, training_data, i, **explain_fn_kwargs) - return self_inf.argsort() + self_inf[i] = explain_fn( + model=model, + model_id=model_id, + cache_dir=cache_dir, + test_tensor=x[None], + test_label=y[None], + train_dataset=training_data, + train_ids=[i], + **explain_fn_kwargs, + ) diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index ebe49ec4..43607047 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -26,8 +26,9 @@ def explain( model_id: str, cache_dir: str, method: str, - test_tensor: torch.Tensor, train_dataset: torch.utils.data.Dataset, + test_tensor: torch.Tensor, + test_target: Optional[torch.Tensor] = None, train_ids: Optional[Union[List[int], torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: From 3649eb8d978fdade1cfe0dd1d1ba413e8d58578f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 13 Jun 2024 19:18:45 +0200 Subject: [PATCH 19/23] return ranking instead of self-influence values --- src/explainers/aggregators/self_influence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index b0f83a14..e7be7367 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -29,3 +29,4 @@ def get_self_influence_ranking( train_ids=[i], **explain_fn_kwargs, ) + return self_inf.argsort() From 7254e17a3c5fc5413967d3a15ba6bacf7ebb8af4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Fri, 14 Jun 2024 19:11:27 +0200 Subject: [PATCH 20/23] changes for test# --- src/explainers/aggregators/self_influence.py | 3 ++- src/utils/explain_wrapper.py | 5 ++--- src/utils/functions/similarities.py | 20 ++++++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index e7be7367..336e93fd 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -18,10 +18,11 @@ def get_self_influence_ranking( warn("train_id is supplied to compute self-influences. Supplied indices will be ignored.") size = len(training_data) self_inf = torch.zeros((size,)) + for i, (x, y) in enumerate(training_data): self_inf[i] = explain_fn( model=model, - model_id=model_id, + model_id=f"{model_id}_id_{i}", cache_dir=cache_dir, test_tensor=x[None], test_label=y[None], diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index 43607047..bb82460d 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -44,10 +44,9 @@ def explain( :param kwargs: :return: """ - if train_ids is not None: - train_dataset = IndexedSubset(dataset=train_dataset, indices=train_ids) - if method == "SimilarityInfluence": + if train_ids is not None: + train_dataset = IndexedSubset(dataset=train_dataset, indices=train_ids) layer = kwargs.get("layer", "features") sim_metric = kwargs.get("similarity_metric", cosine_similarity) sim_direction = kwargs.get("similarity_direction", "max") diff --git a/src/utils/functions/similarities.py b/src/utils/functions/similarities.py index 08675713..699f705a 100644 --- a/src/utils/functions/similarities.py +++ b/src/utils/functions/similarities.py @@ -25,3 +25,23 @@ def cosine_similarity(test, train, replace_nan=0) -> Tensor: similarity = torch.mm(test, train) return similarity + + +def dot_product_similarity(test, train, replace_nan=0) -> Tensor: + """ + Compute cosine similarity between test and train activations. + + :param test: + :param train: + :param replace_nan: + :return: + """ + # TODO: I don't know why Captum return test activations as a list + if isinstance(test, list): + test = torch.cat(test) + assert torch.all(test == train) + test = test.view(test.shape[0], -1) + train = train.view(train.shape[0], -1) + + similarity = torch.mm(test, train.T) + return similarity From 75fab05db787674a72daa0065b0986c8f0a67f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Fri, 14 Jun 2024 19:29:47 +0200 Subject: [PATCH 21/23] aggregator and self influence tests --- .../aggregators/test_aggregators.py | 47 +++++++++++++++++++ .../aggregators/test_self_influence.py | 38 +++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 tests/explainers/aggregators/test_aggregators.py create mode 100644 tests/explainers/aggregators/test_self_influence.py diff --git a/tests/explainers/aggregators/test_aggregators.py b/tests/explainers/aggregators/test_aggregators.py new file mode 100644 index 00000000..cfab2355 --- /dev/null +++ b/tests/explainers/aggregators/test_aggregators.py @@ -0,0 +1,47 @@ +import pytest +import torch + +from src.explainers.aggregators.aggregators import ( + AbsSumAggregator, + SumAggregator, +) + + +@pytest.mark.aggregators +@pytest.mark.parametrize( + "test_id, dataset, explanations", + [ + ( + "mnist", + "load_mnist_dataset", + "load_mnist_explanations_1", + ), + ], +) +def test_sum_aggregator(test_id, dataset, explanations, request): + dataset = request.getfixturevalue(dataset) + explanations = request.getfixturevalue(explanations) + aggregator = SumAggregator(training_size=len(dataset)) + aggregator.update(explanations) + global_rank = aggregator.compute() + assert torch.allclose(global_rank, explanations.sum(dim=0).argsort()) + + +@pytest.mark.aggregators +@pytest.mark.parametrize( + "test_id, dataset, explanations", + [ + ( + "mnist", + "load_mnist_dataset", + "load_mnist_explanations_1", + ), + ], +) +def test_abs_aggregator(test_id, dataset, explanations, request): + dataset = request.getfixturevalue(dataset) + explanations = request.getfixturevalue(explanations) + aggregator = AbsSumAggregator(training_size=len(dataset)) + aggregator.update(explanations) + global_rank = aggregator.compute() + assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort()) diff --git a/tests/explainers/aggregators/test_self_influence.py b/tests/explainers/aggregators/test_self_influence.py new file mode 100644 index 00000000..b43b9dae --- /dev/null +++ b/tests/explainers/aggregators/test_self_influence.py @@ -0,0 +1,38 @@ +from collections import OrderedDict + +import pytest +import torch +from torch.utils.data import TensorDataset + +from src.explainers.aggregators.self_influence import ( + get_self_influence_ranking, +) +from src.utils.explain_wrapper import explain +from src.utils.functions.similarities import dot_product_similarity + + +@pytest.mark.self_influence +@pytest.mark.parametrize( + "test_id, explain_kwargs", + [ + ( + "random_data", + {"method": "SimilarityInfluence", "layer": "identity", "similarity_metric": dot_product_similarity}, + ), + ], +) +def test_self_influence_ranking(test_id, explain_kwargs, request): + model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())])) + X = torch.randn(100, 200) + rand_dataset = TensorDataset(X, torch.randint(0, 10, (100,))) + + self_influence_rank = get_self_influence_ranking( + model=model, + model_id="0", + cache_dir="temp_captum", + training_data=rand_dataset, + explain_fn=explain, + explain_fn_kwargs=explain_kwargs, + ) + + assert torch.allclose(self_influence_rank, torch.linalg.norm(X, dim=-1).argsort()) From e2e386aedb635faf9800a3fe0163491f959a0052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Fri, 14 Jun 2024 19:49:20 +0200 Subject: [PATCH 22/23] revert omitted call to base class initializer --- src/metrics/unnamed/top_k_overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metrics/unnamed/top_k_overlap.py b/src/metrics/unnamed/top_k_overlap.py index cb872726..e5f70235 100644 --- a/src/metrics/unnamed/top_k_overlap.py +++ b/src/metrics/unnamed/top_k_overlap.py @@ -13,7 +13,7 @@ def __init__( *args, **kwargs, ): - # super().__init__(model, train_dataset, *args, **kwargs) + super().__init__(model, train_dataset, *args, **kwargs) self.top_k = top_k self.all_top_k_examples = torch.empty(0, top_k) From 0c013221076f6e35f76fce8485216a1474d8ebab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Fri, 14 Jun 2024 20:17:59 +0200 Subject: [PATCH 23/23] Fix explainer_wrapper test --- pytest.ini | 2 ++ src/explainers/aggregators/self_influence.py | 3 --- tests/utils/test_explain_wrapper.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytest.ini b/pytest.ini index d715cfff..903adf8d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,3 +5,5 @@ markers = localization_metrics: localization_metrics unnamed_metrics: unnamed_metrics randomization_metrics: randomization_metrics + aggregators: aggregators + self_influence: self_influence diff --git a/src/explainers/aggregators/self_influence.py b/src/explainers/aggregators/self_influence.py index 336e93fd..96406e62 100644 --- a/src/explainers/aggregators/self_influence.py +++ b/src/explainers/aggregators/self_influence.py @@ -1,5 +1,4 @@ from typing import Optional -from warnings import warn import torch @@ -14,8 +13,6 @@ def get_self_influence_ranking( explain_fn: ExplainFunc, explain_fn_kwargs: Optional[dict] = None, ) -> torch.Tensor: - if "train_ids" not in explain_fn_kwargs: - warn("train_id is supplied to compute self-influences. Supplied indices will be ignored.") size = len(training_data) self_inf = torch.zeros((size,)) diff --git a/tests/utils/test_explain_wrapper.py b/tests/utils/test_explain_wrapper.py index a88c859d..1e03136c 100644 --- a/tests/utils/test_explain_wrapper.py +++ b/tests/utils/test_explain_wrapper.py @@ -30,9 +30,9 @@ def test_explain(test_id, model, dataset, explanations, test_tensor, method, met model, test_id, os.path.join("./cache", "test_id"), + method, dataset, test_tensor, - method, **method_kwargs, ) assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"