Skip to content

Commit

Permalink
sort explain fns
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Jun 20, 2024
1 parent 950b8b8 commit 5cded75
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 97 deletions.
3 changes: 2 additions & 1 deletion src/explainers/aggregators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch


class BaseAggregator(ABC):
def __init__(self):
self.scores: torch.Tensor = None
self.scores: Optional[torch.Tensor] = None

@abstractmethod
def update(self, explanations: torch.Tensor):
Expand Down
89 changes: 0 additions & 89 deletions src/explainers/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import torch

from src.explainers.wrappers.captum_influence import CaptumSimilarity


class ExplainFunc(Protocol):
def __call__(
Expand All @@ -19,90 +17,3 @@ def __call__(
device: Union[str, torch.device],
) -> torch.Tensor:
pass


def explain_fn_from_explainer(
explainer_cls: type,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
explanation_targets: Optional[Union[List[int], torch.Tensor]],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
init_kwargs: Optional[Dict] = {},
explain_kwargs: Optional[Dict] = {},
) -> torch.Tensor:
explainer = explainer_cls(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
**init_kwargs,
)
return explainer.explain(test=test_tensor, **explain_kwargs)


def explainer_self_influence_interface(
explainer_cls: type,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
init_kwargs: Dict,
device: Union[str, torch.device],
) -> torch.Tensor:
explainer = explainer_cls(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
**init_kwargs,
)
return explainer.self_influence()


def captum_similarity_explain(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
explanation_targets: Optional[Union[List[int], torch.Tensor]],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
init_kwargs: Optional[Dict] = {},
explain_kwargs: Optional[Dict] = {},
) -> torch.Tensor:
return explain_fn_from_explainer(
explainer_cls=CaptumSimilarity,
model=model,
model_id=model_id,
cache_dir=cache_dir,
test_tensor=test_tensor,
explanation_targets=explanation_targets,
train_dataset=train_dataset,
device=device,
init_kwargs=init_kwargs,
explain_kwargs=explain_kwargs,
)


def captum_similarity_self_influence_ranking(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
init_kwargs: Dict,
device: Union[str, torch.device],
) -> torch.Tensor:
return explainer_self_influence_interface(
explainer_cls=CaptumSimilarity,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
init_kwargs=init_kwargs,
)
56 changes: 56 additions & 0 deletions src/explainers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Dict, List, Optional, Union

import torch


def explain_fn_from_explainer(
explainer_cls: type,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
targets: Optional[Union[List[int], torch.Tensor]] = None,
init_kwargs: Optional[Dict] = None,
explain_kwargs: Optional[Dict] = None,
) -> torch.Tensor:

init_kwargs = init_kwargs or {}
explain_kwargs = explain_kwargs or {}

explainer = explainer_cls(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
**init_kwargs,
)
return explainer.explain(test=test_tensor, targets=targets, **explain_kwargs)


def self_influence_fn_from_explainer(
explainer_cls: type,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
batch_size: Optional[int] = 32,
init_kwargs: Optional[Dict] = None,
explain_kwargs: Optional[Dict] = None,
) -> torch.Tensor:

init_kwargs = init_kwargs or {}
explain_kwargs = explain_kwargs or {}

explainer = explainer_cls(
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
**init_kwargs,
)
return explainer.self_influence(batch_size=batch_size, **explain_kwargs)
57 changes: 56 additions & 1 deletion src/explainers/wrappers/captum_influence.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from captum.influence import SimilarityInfluence

from src.explainers.base_explainer import BaseExplainer
from src.explainers.utils import (
explain_fn_from_explainer,
self_influence_fn_from_explainer,
)
from src.utils.validation import validate_1d_tensor_or_int_list


Expand Down Expand Up @@ -123,3 +127,54 @@ def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.T
tda = torch.gather(topk_val, 1, inverted_idx)

return tda


def captum_similarity_explain(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
explanation_targets: Optional[Union[List[int], torch.Tensor]],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
init_kwargs: Optional[Dict] = None,
explain_kwargs: Optional[Dict] = None,
) -> torch.Tensor:

init_kwargs = init_kwargs or {}
explain_kwargs = explain_kwargs or {}

return explain_fn_from_explainer(
explainer_cls=CaptumSimilarity,
model=model,
model_id=model_id,
cache_dir=cache_dir,
test_tensor=test_tensor,
targets=explanation_targets,
train_dataset=train_dataset,
device=device,
init_kwargs=init_kwargs,
explain_kwargs=explain_kwargs,
)


def captum_similarity_self_influence(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
init_kwargs: Dict,
device: Union[str, torch.device],
) -> torch.Tensor:

init_kwargs = init_kwargs or {}

return self_influence_fn_from_explainer(
explainer_cls=CaptumSimilarity,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
init_kwargs=init_kwargs,
)
6 changes: 4 additions & 2 deletions tests/explainers/test_explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import pytest
import torch

from src.explainers.functional import captum_similarity_explain
from src.explainers.wrappers.captum_influence import CaptumSimilarity
from src.explainers.wrappers.captum_influence import (
CaptumSimilarity,
captum_similarity_explain,
)
from src.utils.functions.similarities import cosine_similarity


Expand Down
8 changes: 5 additions & 3 deletions tests/explainers/test_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import torch
from torch.utils.data import TensorDataset

from src.explainers.functional import captum_similarity_self_influence_ranking
from src.explainers.wrappers.captum_influence import CaptumSimilarity
from src.explainers.wrappers.captum_influence import (
CaptumSimilarity,
captum_similarity_self_influence,
)
from src.utils.functions.similarities import dot_product_similarity


Expand All @@ -29,7 +31,7 @@ def test_self_influence(test_id, init_kwargs, request):
y = torch.randint(0, 10, (100,))
rand_dataset = TensorDataset(X, y)

self_influence_rank_functional = captum_similarity_self_influence_ranking(
self_influence_rank_functional = captum_similarity_self_influence(
model=model,
model_id="0",
cache_dir="temp_captum",
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/test_randomization_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from src.explainers.functional import captum_similarity_explain
from src.explainers.wrappers.captum_influence import captum_similarity_explain
from src.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)
Expand Down

0 comments on commit 5cded75

Please sign in to comment.