diff --git a/quanda/explainers/base.py b/quanda/explainers/base.py
index 2fc977d6..8130f416 100644
--- a/quanda/explainers/base.py
+++ b/quanda/explainers/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Sized, Union
+from typing import List, Optional, Sized, Union

 import torch

@@ -54,7 +54,7 @@ def _process_targets(self, targets: Optional[Union[List[int], torch.Tensor]]):
         return targets

     @cache_result
-    def self_influence(self, **kwargs: Any) -> torch.Tensor:
+    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
         """
         Base class implements computing self influences by explaining the train dataset one by one

@@ -62,7 +62,6 @@ def self_influence(self, **kwargs: Any) -> torch.Tensor:
         """
         :param kwargs:
         :return:
         """
-        batch_size = kwargs.get("batch_size", 32)

         # Pre-allcate memory for influences, because torch.cat is slow
         influences = torch.empty((self.dataset_length,), device=self.device)
diff --git a/quanda/explainers/utils.py b/quanda/explainers/utils.py
index 5b5a572d..e1613932 100644
--- a/quanda/explainers/utils.py
+++ b/quanda/explainers/utils.py
@@ -40,9 +40,9 @@ def self_influence_fn_from_explainer(
     explainer_cls: type,
     model: torch.nn.Module,
     train_dataset: torch.utils.data.Dataset,
-    self_influence_kwargs: dict,
     cache_dir: Optional[str] = None,
     model_id: Optional[str] = None,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
     explainer = _init_explainer(
@@ -54,4 +54,4 @@ def self_influence_fn_from_explainer(
         **kwargs,
     )

-    return explainer.self_influence(**self_influence_kwargs)
+    return explainer.self_influence(batch_size=batch_size)
diff --git a/quanda/explainers/wrappers/captum_influence.py b/quanda/explainers/wrappers/captum_influence.py
index 3690637b..0e9fe1bb 100644
--- a/quanda/explainers/wrappers/captum_influence.py
+++ b/quanda/explainers/wrappers/captum_influence.py
@@ -1,7 +1,7 @@
 import copy
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Union

 import torch
 from captum.influence import SimilarityInfluence, TracInCP  # type: ignore
@@ -172,19 +172,16 @@ def captum_similarity_self_influence(
     model_id: str,
     cache_dir: Optional[str],
     train_dataset: torch.utils.data.Dataset,
-    batch_size: Optional[int] = 32,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
-    self_influence_kwargs = {
-        "batch_size": batch_size,
-    }
     return self_influence_fn_from_explainer(
         explainer_cls=CaptumSimilarity,
         model=model,
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        self_influence_kwargs=self_influence_kwargs,
+        batch_size=batch_size,
         **kwargs,
     )

@@ -272,9 +269,8 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T
         influence_scores = self.captum_explainer.influence(inputs=(test, targets))
         return influence_scores

-    def self_influence(self, **kwargs: Any) -> torch.Tensor:
-        inputs_dataset = kwargs.get("inputs_dataset", None)
-        influence_scores = self.captum_explainer.self_influence(inputs_dataset=inputs_dataset)
+    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
+        influence_scores = self.captum_explainer.self_influence(inputs_dataset=None)
         return influence_scores

@@ -282,7 +278,6 @@ def captum_arnoldi_explain(
     model: torch.nn.Module,
     test_tensor: torch.Tensor,
     train_dataset: torch.utils.data.Dataset,
-    device: Union[str, torch.device],
     explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
     model_id: Optional[str] = None,
     cache_dir: Optional[str] = None,
@@ -303,22 +298,18 @@ def captum_arnoldi_self_influence(
     model: torch.nn.Module,
     train_dataset: torch.utils.data.Dataset,
-    device: Union[str, torch.device],
-    inputs_dataset: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
     model_id: Optional[str] = None,
     cache_dir: Optional[str] = None,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
-    self_influence_kwargs = {
-        "inputs_dataset": inputs_dataset,
-    }
     return self_influence_fn_from_explainer(
         explainer_cls=CaptumArnoldi,
         model=model,
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        self_influence_kwargs=self_influence_kwargs,
+        batch_size=batch_size,
         **kwargs,
     )

@@ -351,6 +342,7 @@ def __init__(
                 explainer_kwargs.pop(arg)
                 warnings.warn(f"{arg} is not supported by CaptumTraceInCP explainer. Ignoring the argument.")

+        self.outer_loop_by_checkpoints = explainer_kwargs.pop("outer_loop_by_checkpoints", False)
         explainer_kwargs.update(
             {
                 "model": model,
@@ -388,11 +380,9 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T
         influence_scores = self.captum_explainer.influence(inputs=(test, targets))
         return influence_scores

-    def self_influence(self, **kwargs: Any) -> torch.Tensor:
-        inputs = kwargs.get("inputs", None)
-        outer_loop_by_checkpoints = kwargs.get("outer_loop_by_checkpoints", False)
+    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
         influence_scores = self.captum_explainer.self_influence(
-            inputs=inputs, outer_loop_by_checkpoints=outer_loop_by_checkpoints
+            inputs=None, outer_loop_by_checkpoints=self.outer_loop_by_checkpoints
         )
         return influence_scores

@@ -421,19 +411,17 @@ def captum_tracincp_explain(
 def captum_tracincp_self_influence(
     model: torch.nn.Module,
     train_dataset: torch.utils.data.Dataset,
-    inputs: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
-    outer_loop_by_checkpoints: bool = False,
     model_id: Optional[str] = None,
     cache_dir: Optional[str] = None,
+    batch_size: int = 32,
     **kwargs: Any,
 ) -> torch.Tensor:
-    self_influence_kwargs = {"inputs": inputs, "outer_loop_by_checkpoints": outer_loop_by_checkpoints}
     return self_influence_fn_from_explainer(
         explainer_cls=CaptumTracInCP,
         model=model,
         model_id=model_id,
         cache_dir=cache_dir,
         train_dataset=train_dataset,
-        self_influence_kwargs=self_influence_kwargs,
+        batch_size=batch_size,
         **kwargs,
     )
diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py
index f9e9d3d5..fb916e21 100644
--- a/tests/explainers/wrappers/test_captum_influence.py
+++ b/tests/explainers/wrappers/test_captum_influence.py
@@ -473,7 +473,6 @@ def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, me
         checkpoints=checkpoints,
         checkpoints_load_func=get_load_state_dict_func("cpu"),
         device="cpu",
-        outer_loop_by_checkpoints=True,
         **method_kwargs,
     )
     assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"
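
For reference, a minimal sketch of how the updated self_influence API would be called after this patch, using the similarity-based wrapper: batch_size becomes a plain keyword argument instead of an entry in a self_influence_kwargs dict or **kwargs. The toy model, dataset, model_id, cache_dir, and layers values below are illustrative assumptions rather than values taken from the repository; in particular, whether CaptumSimilarity requires a layers argument (and its exact form) should be checked against the wrapper's constructor.

import torch
from torch.utils.data import TensorDataset

from quanda.explainers.wrappers.captum_influence import (
    CaptumSimilarity,
    captum_similarity_self_influence,
)

# Toy model and dataset, only to illustrate the call shape.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8, 2))
train_dataset = TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,)))

# Functional interface: batch_size is now a top-level keyword argument.
scores = captum_similarity_self_influence(
    model=model,
    model_id="demo_model",        # placeholder identifier
    cache_dir="./cache",          # placeholder cache directory
    train_dataset=train_dataset,
    batch_size=8,
    layers="1",                   # assumed: name of the layer used for similarity (the Linear module above)
)

# Class-based interface: self_influence() no longer accepts **kwargs.
explainer = CaptumSimilarity(
    model=model,
    model_id="demo_model",
    cache_dir="./cache",
    train_dataset=train_dataset,
    layers="1",
)
scores = explainer.self_influence(batch_size=8)
print(scores.shape)  # one self-influence score per training example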