remove self-influence kwargs from everywhere
dilyabareeva committed Aug 26, 2024
1 parent cf21d46 commit f4a6674
Showing 4 changed files with 16 additions and 30 deletions.
5 changes: 2 additions & 3 deletions quanda/explainers/base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Sized, Union
from typing import List, Optional, Sized, Union

import torch

@@ -54,15 +54,14 @@ def _process_targets(self, targets: Optional[Union[List[int], torch.Tensor]]):
return targets

@cache_result
def self_influence(self, **kwargs: Any) -> torch.Tensor:
def self_influence(self, batch_size: int = 32) -> torch.Tensor:
"""
Base class implements computing self influences by explaining the train dataset one by one
:param batch_size:
:param kwargs:
:return:
"""
batch_size = kwargs.get("batch_size", 32)

# Pre-allocate memory for influences, because torch.cat is slow
influences = torch.empty((self.dataset_length,), device=self.device)
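For illustration, here is a minimal, self-contained sketch of the batching pattern the base-class self_influence follows after this change, with batch_size as an explicit argument rather than something fished out of **kwargs. ToyExplainer and its dot-product explain are stand-ins, not quanda's actual classes; only the pre-allocation loop and the new signature mirror the diff above.

import torch


class ToyExplainer:
    def __init__(self, train_dataset: torch.utils.data.Dataset, device: str = "cpu"):
        self.train_dataset = train_dataset
        self.dataset_length = len(train_dataset)
        self.device = device

    def explain(self, test: torch.Tensor) -> torch.Tensor:
        # Stand-in attribution: influence of every train point on every test point.
        train = torch.stack([self.train_dataset[i][0] for i in range(self.dataset_length)])
        return test @ train.T  # shape: (n_test, n_train)

    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
        # Pre-allocate memory for influences, because torch.cat is slow.
        influences = torch.empty((self.dataset_length,), device=self.device)
        for start in range(0, self.dataset_length, batch_size):
            end = min(start + batch_size, self.dataset_length)
            batch = torch.stack([self.train_dataset[i][0] for i in range(start, end)])
            # Self-influence of a train point is its influence on itself:
            # the diagonal of this batch's block of the attribution matrix.
            influences[start:end] = self.explain(batch)[:, start:end].diag()
        return influences


if __name__ == "__main__":
    data = torch.utils.data.TensorDataset(torch.randn(100, 8))
    print(ToyExplainer(data).self_influence(batch_size=16).shape)  # torch.Size([100])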
4 changes: 2 additions & 2 deletions quanda/explainers/utils.py
@@ -40,9 +40,9 @@ def self_influence_fn_from_explainer(
explainer_cls: type,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
self_influence_kwargs: dict,
cache_dir: Optional[str] = None,
model_id: Optional[str] = None,
batch_size: int = 32,
**kwargs: Any,
) -> torch.Tensor:
explainer = _init_explainer(
@@ -54,4 +54,4 @@ def self_influence_fn_from_explainer(
**kwargs,
)

return explainer.self_influence(**self_influence_kwargs)
return explainer.self_influence(batch_size=batch_size)
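To make the caller-side effect of the utils change concrete, the sketch below pairs a stand-in explainer with a reduced self_influence_fn_from_explainer whose keyword arguments mirror the new signature: before this commit the batch size travelled inside a self_influence_kwargs dict, now it is a plain batch_size keyword. Everything apart from that argument shape is assumed for illustration.

from typing import Any

import torch


class DummyExplainer:
    def __init__(self, model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, **kwargs: Any):
        self.train_dataset = train_dataset

    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
        return torch.zeros(len(self.train_dataset))


def self_influence_fn_from_explainer(
    explainer_cls: type,
    model: torch.nn.Module,
    train_dataset: torch.utils.data.Dataset,
    batch_size: int = 32,
    **kwargs: Any,
) -> torch.Tensor:
    # Remaining kwargs go to the explainer constructor; batch_size goes to self_influence.
    explainer = explainer_cls(model=model, train_dataset=train_dataset, **kwargs)
    return explainer.self_influence(batch_size=batch_size)


dataset = torch.utils.data.TensorDataset(torch.randn(10, 4))
scores = self_influence_fn_from_explainer(
    explainer_cls=DummyExplainer,
    model=torch.nn.Linear(4, 2),
    train_dataset=dataset,
    batch_size=4,  # previously: self_influence_kwargs={"batch_size": 4}
)
print(scores.shape)  # torch.Size([10])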
36 changes: 12 additions & 24 deletions quanda/explainers/wrappers/captum_influence.py
@@ -1,7 +1,7 @@
import copy
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, Union

import torch
from captum.influence import SimilarityInfluence, TracInCP # type: ignore
@@ -172,19 +172,16 @@ def captum_similarity_self_influence(
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
batch_size: Optional[int] = 32,
batch_size: int = 32,
**kwargs: Any,
) -> torch.Tensor:
self_influence_kwargs = {
"batch_size": batch_size,
}
return self_influence_fn_from_explainer(
explainer_cls=CaptumSimilarity,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
self_influence_kwargs=self_influence_kwargs,
batch_size=batch_size,
**kwargs,
)

@@ -272,17 +269,15 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T
influence_scores = self.captum_explainer.influence(inputs=(test, targets))
return influence_scores

def self_influence(self, **kwargs: Any) -> torch.Tensor:
inputs_dataset = kwargs.get("inputs_dataset", None)
influence_scores = self.captum_explainer.self_influence(inputs_dataset=inputs_dataset)
def self_influence(self, batch_size: int = 32) -> torch.Tensor:
influence_scores = self.captum_explainer.self_influence(inputs_dataset=None)
return influence_scores


def captum_arnoldi_explain(
model: torch.nn.Module,
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
@@ -303,22 +298,18 @@ def captum_arnoldi_self_influence(
def captum_arnoldi_self_influence(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
inputs_dataset: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
batch_size: int = 32,
**kwargs: Any,
) -> torch.Tensor:
self_influence_kwargs = {
"inputs_dataset": inputs_dataset,
}
return self_influence_fn_from_explainer(
explainer_cls=CaptumArnoldi,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
self_influence_kwargs=self_influence_kwargs,
batch_size=batch_size,
**kwargs,
)

@@ -351,6 +342,7 @@ def __init__(
explainer_kwargs.pop(arg)
warnings.warn(f"{arg} is not supported by CaptumTraceInCP explainer. Ignoring the argument.")

self.outer_loop_by_checkpoints = explainer_kwargs.pop("outer_loop_by_checkpoints", False)
explainer_kwargs.update(
{
"model": model,
@@ -388,11 +380,9 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T
influence_scores = self.captum_explainer.influence(inputs=(test, targets))
return influence_scores

def self_influence(self, **kwargs: Any) -> torch.Tensor:
inputs = kwargs.get("inputs", None)
outer_loop_by_checkpoints = kwargs.get("outer_loop_by_checkpoints", False)
def self_influence(self, batch_size: int = 32) -> torch.Tensor:
influence_scores = self.captum_explainer.self_influence(
inputs=inputs, outer_loop_by_checkpoints=outer_loop_by_checkpoints
inputs=None, outer_loop_by_checkpoints=self.outer_loop_by_checkpoints
)
return influence_scores

@@ -421,19 +411,17 @@ def captum_tracincp_explain(
def captum_tracincp_self_influence(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
inputs: Optional[Union[Tuple[Any, ...], torch.utils.data.DataLoader]] = None,
outer_loop_by_checkpoints: bool = False,
model_id: Optional[str] = None,
cache_dir: Optional[str] = None,
batch_size: int = 32,
**kwargs: Any,
) -> torch.Tensor:
self_influence_kwargs = {"inputs": inputs, "outer_loop_by_checkpoints": outer_loop_by_checkpoints}
return self_influence_fn_from_explainer(
explainer_cls=CaptumTracInCP,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
self_influence_kwargs=self_influence_kwargs,
batch_size=batch_size,
**kwargs,
)
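The TracInCP wrapper change above also moves outer_loop_by_checkpoints out of self_influence: it is now popped from the explainer kwargs in __init__ and reused internally. The toy class below (not quanda's actual wrapper) only illustrates that relocation.

from typing import Any

import torch


class ToyTracInCPWrapper:
    def __init__(self, **explainer_kwargs: Any):
        # Popped at construction time, before the remaining kwargs would go to Captum.
        self.outer_loop_by_checkpoints = explainer_kwargs.pop("outer_loop_by_checkpoints", False)
        self.explainer_kwargs = explainer_kwargs

    def self_influence(self, batch_size: int = 32) -> torch.Tensor:
        # The stored flag is used here instead of a per-call keyword argument.
        print(f"outer_loop_by_checkpoints={self.outer_loop_by_checkpoints}, batch_size={batch_size}")
        return torch.zeros(1)


wrapper = ToyTracInCPWrapper(outer_loop_by_checkpoints=True)
wrapper.self_influence(batch_size=8)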
1 change: 0 additions & 1 deletion tests/explainers/wrappers/test_captum_influence.py
@@ -473,7 +473,6 @@ def test_captum_tracincp_self_influence(test_id, model, dataset, checkpoints, me
checkpoints=checkpoints,
checkpoints_load_func=get_load_state_dict_func("cpu"),
device="cpu",
outer_loop_by_checkpoints=True,
**method_kwargs,
)
assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"
