From f4a66741ef967cbdb69ce2e059aabcae12764f09 Mon Sep 17 00:00:00 2001
From: dilyabareeva <dilyabareeva@gmail.com>
Date: Mon, 26 Aug 2024 11:31:29 +0200
Subject: [PATCH] remove self-influence kwargs from everywhere

---
 quanda/explainers/base.py                     |  5 ++-
 quanda/explainers/utils.py                    |  4 +--
 .../explainers/wrappers/captum_influence.py   | 36 +++++++------------
 .../wrappers/test_captum_influence.py         |  1 -
 4 files changed, 16 insertions(+), 30 deletions(-)

diff --git a/quanda/explainers/base.py b/quanda/explainers/base.py
index 2fc977d6..8130f416 100644
--- a/quanda/explainers/base.py
+++ b/quanda/explainers/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Sized, Union
+from typing import List, Optional, Sized, Union
 
 import torch
 
@@ -54,7 +54,7 @@ def _process_targets(self, targets: Optional[Union[List[int], torch.Tensor]]):
         return targets
 
     @cache_result
-    def self_influence(self, **kwargs: Any) -> torch.Tensor:
+    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
         """
         Base class implements computing self influences by explaining the train dataset one by one
 
@@ -62,7 +62,6 @@ def self_influence(self, **kwargs: Any) -> torch.Tensor:
         :param kwargs:
         :return:
         """
-        batch_size = kwargs.get("batch_size", 32)
 
         # Pre-allcate memory for influences, because torch.cat is slow
         influences = torch.empty((self.dataset_length,), device=self.device)
diff --git a/quanda/explainers/utils.py b/quanda/explainers/utils.py
index 5b5a572d..e1613932 100644
--- a/quanda/explainers/utils.py
+++ b/quanda/explainers/utils.py
@@ -40,9 +40,9 @@ def self_influence_fn_from_explainer(
     explainer_cls: type,
     model: torch.nn.Module,
     train_dataset: torch.utils.data.Dataset,
-    self_influence_kwargs: dict,
     cache_dir: Optional[str] = None,
     model_id: Optional[str] = None,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
     explainer = _init_explainer(
@@ -54,4 +54,4 @@ def self_influence_fn_from_explainer(
         **kwargs,
     )
 
-    return explainer.self_influence(**self_influence_kwargs)
+    return explainer.self_influence(batch_size=batch_size)
diff --git a/quanda/explainers/wrappers/captum_influence.py b/quanda/explainers/wrappers/captum_influence.py
index 3690637b..0e9fe1bb 100644
--- a/quanda/explainers/wrappers/captum_influence.py
+++ b/quanda/explainers/wrappers/captum_influence.py
@@ -1,7 +1,7 @@
 import copy
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Union
 
 import torch
 from captum.influence import SimilarityInfluence, TracInCP  # type: ignore
@@ -172,19 +172,16 @@ def captum_similarity_self_influence(
     model_id: str,
     cache_dir: Optional[str],
     train_dataset: torch.utils.data.Dataset,
-    batch_size: Optional[int] = 32,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
-    self_influence_kwargs = {
-        "batch_size": batch_size,
-    }
     return self_influence_fn_from_explainer(
         explainer_cls=CaptumSimilarity,
         model=model,
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        self_influence_kwargs=self_influence_kwargs,
+        batch_size=batch_size,
         **kwargs,
     )
 
@@ -272,9 +269,8 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T
         influence_scores = self.captum_explainer.influence(inputs=(test, targets))
         return influence_scores
 
-    def self_influence(self, **kwargs: Any) -> torch.Tensor:
-        inputs_dataset = kwargs.get("inputs_dataset", None)
-        influence_scores = self.captum_explainer.self_influence(inputs_dataset=inputs_dataset)
+    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
+        influence_scores = self.captum_explainer.self_influence(inputs_dataset=None)
         return influence_scores
 
 
@@ -282,7 +278,6 @@ def captum_arnoldi_explain(
     model: torch.nn.Module,
     test_tensor: torch.Tensor,
     train_dataset: torch.utils.data.Dataset,
-    device: Union[str, torch.device],
     explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
     model_id: Optional[str] = None,
     cache_dir: Optional[str] = None,
@@ -303,22 +298,18 @@ def captum_arnoldi_explain(
 def captum_arnoldi_self_influence(
     model: torch.nn.Module,
     train_dataset: torch.utils.data.Dataset,
-    device: Union[str, torch.device],
-    inputs_dataset: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
     model_id: Optional[str] = None,
     cache_dir: Optional[str] = None,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
-    self_influence_kwargs = {
-        "inputs_dataset": inputs_dataset,
-    }
     return self_influence_fn_from_explainer(
         explainer_cls=CaptumArnoldi,
         model=model,
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        self_influence_kwargs=self_influence_kwargs,
+        batch_size=batch_size,
         **kwargs,
     )
 
@@ -351,6 +342,7 @@ def __init__(
                 explainer_kwargs.pop(arg)
                 warnings.warn(f"{arg} is not supported by CaptumTraceInCP explainer. Ignoring the argument.")
 
+        self.outer_loop_by_checkpoints = explainer_kwargs.pop("outer_loop_by_checkpoints", False)
         explainer_kwargs.update(
             {
                 "model": model,
@@ -388,11 +380,9 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T
         influence_scores = self.captum_explainer.influence(inputs=(test, targets))
         return influence_scores
 
-    def self_influence(self, **kwargs: Any) -> torch.Tensor:
-        inputs = kwargs.get("inputs", None)
-        outer_loop_by_checkpoints = kwargs.get("outer_loop_by_checkpoints", False)
+    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
         influence_scores = self.captum_explainer.self_influence(
-            inputs=inputs, outer_loop_by_checkpoints=outer_loop_by_checkpoints
+            inputs=None, outer_loop_by_checkpoints=self.outer_loop_by_checkpoints
         )
         return influence_scores
 
@@ -421,19 +411,17 @@ def captum_tracincp_explain(
 def captum_tracincp_self_influence(
     model: torch.nn.Module,
     train_dataset: torch.utils.data.Dataset,
-    inputs: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
-    outer_loop_by_checkpoints: bool = False,
     model_id: Optional[str] = None,
     cache_dir: Optional[str] = None,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
-    self_influence_kwargs = {"inputs": inputs, "outer_loop_by_checkpoints": outer_loop_by_checkpoints}
     return self_influence_fn_from_explainer(
         explainer_cls=CaptumTracInCP,
         model=model,
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        self_influence_kwargs=self_influence_kwargs,
+        batch_size=batch_size,
         **kwargs,
     )
diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py
index f9e9d3d5..fb916e21 100644
--- a/tests/explainers/wrappers/test_captum_influence.py
+++ b/tests/explainers/wrappers/test_captum_influence.py
@@ -473,7 +473,6 @@ def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, me
         checkpoints=checkpoints,
         checkpoints_load_func=get_load_state_dict_func("cpu"),
         device="cpu",
-        outer_loop_by_checkpoints=True,
         **method_kwargs,
     )
     assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"