From 85f3b4ba525e0ff57bbc9ee08c13e0227d37b03e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Sun, 9 Jun 2024 19:19:43 +0200 Subject: [PATCH] ExplanationsAggregator and get_self_influence_ranking --- src/utils/aggregators.py | 24 +++++++++++++++++ src/utils/common.py | 21 ++++++++++++++- src/utils/explain_wrapper.py | 12 +++++++++ src/utils/globalization/__init__.py | 0 src/utils/globalization/base.py | 11 -------- src/utils/globalization/from_explainer.py | 27 -------------------- src/utils/globalization/from_explanations.py | 8 ------ 7 files changed, 56 insertions(+), 47 deletions(-) create mode 100644 src/utils/aggregators.py delete mode 100644 src/utils/globalization/__init__.py delete mode 100644 src/utils/globalization/base.py delete mode 100644 src/utils/globalization/from_explainer.py delete mode 100644 src/utils/globalization/from_explanations.py diff --git a/src/utils/aggregators.py b/src/utils/aggregators.py new file mode 100644 index 00000000..df07e825 --- /dev/null +++ b/src/utils/aggregators.py @@ -0,0 +1,24 @@ +from abc import ABC + +import torch + + +class ExplanationsAggregator(ABC): + def __init__(self, training_size: int, *args, **kwargs): + self.scores = torch.zeros(training_size) + + def update(self, explanations: torch.Tensor): + raise NotImplementedError + + def get_global_ranking(self) -> torch.Tensor: + return self.scores.argsort() + + +class SumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: + self.scores += explanations.sum(dim=0) + + +class AbsSumAggregator(ExplanationsAggregator): + def update(self, explanations: torch.Tensor) -> torch.Tensor: + self.scores += explanations.abs().sum(dim=0) diff --git a/src/utils/common.py b/src/utils/common.py index 45e34a05..b5686ddf 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -1,8 +1,12 @@ import functools from functools import reduce -from typing import Any, Callable, Mapping +from typing import Any, Callable, Mapping, Optional import torch +import torch.utils +import torch.utils.data + +from utils.explain_wrapper import SelfInfluenceFunction def _get_module_from_name(model: torch.nn.Module, layer_name: str) -> Any: @@ -22,3 +26,18 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None, **kwargs) - func_kwargs = kwargs return functools.partial(func, **func_kwargs) + + +def get_self_influence_ranking( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + training_data: torch.utils.data.Dataset, + self_influence_fn: SelfInfluenceFunction, + self_influence_fn_kwargs: Optional[dict] = None, +) -> torch.Tensor: + size = len(training_data) + self_inf = torch.zeros((size,)) + for i, (x, y) in enumerate(training_data): + self_inf[i] = self_influence_fn(model, model_id, cache_dir, training_data, i) + return self_inf.argsort() diff --git a/src/utils/explain_wrapper.py b/src/utils/explain_wrapper.py index 03e4775e..367b6f5c 100644 --- a/src/utils/explain_wrapper.py +++ b/src/utils/explain_wrapper.py @@ -19,6 +19,18 @@ def __call__( pass +class SelfInfluenceFunction(Protocol): + def __call__( + self, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + id: int, + ) -> torch.Tensor: + pass + + def explain( model: torch.nn.Module, model_id: str, diff --git a/src/utils/globalization/__init__.py b/src/utils/globalization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/globalization/base.py b/src/utils/globalization/base.py deleted file mode 100644 index 947b5266..00000000 --- a/src/utils/globalization/base.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - - -class Globalization: - def __init__(self, training_dataset: torch.utils.data.Dataset, *args, **kwargs): - self.dataset = training_dataset - self.scores = torch.zeros((len(training_dataset))) - raise NotImplementedError - - def get_global_ranking(self): - return self.scores.argmax() diff --git a/src/utils/globalization/from_explainer.py b/src/utils/globalization/from_explainer.py deleted file mode 100644 index 13c03c70..00000000 --- a/src/utils/globalization/from_explainer.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Callable, Optional - -import torch - -from src.utils.common import make_func -from src.utils.globalization.base import Globalization - - -class GlobalizationFromSingleImageAttributor(Globalization): - def __init__( - self, - training_dataset: torch.utils.data.Dataset, - model: torch.nn.Module, - attributor_fn: Callable, - attributor_fn_kwargs: Optional[dict] = None, - ): - # why is it called attributor - super().__init__(training_dataset=training_dataset) - self.attributor_fn = make_func(func=attributor_fn, func_kwargs=attributor_fn_kwargs, model=self.model) - self.model = model - - def compute_self_influences(self): - for i, (x, _) in enumerate(self.training_dataset): - self.scores[i] = self.attributor_fn(datapoint=x) - - def update_self_influences(self, self_influences): - self.scores = self_influences diff --git a/src/utils/globalization/from_explanations.py b/src/utils/globalization/from_explanations.py deleted file mode 100644 index 8d027655..00000000 --- a/src/utils/globalization/from_explanations.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch - -from src.utils.globalization.base import Globalization - - -class GlobalizationFromExplanations(Globalization): - def update(self, explanations: torch.Tensor): - self.scores += explanations.abs().sum(dim=0)