Commit: Add missing files

gumityolcu committed Jun 18, 2024
1 parent c5ddc04 commit 779c0ff
Showing 2 changed files with 186 additions and 0 deletions.
109 changes: 109 additions & 0 deletions src/explainers/functional.py
@@ -0,0 +1,109 @@
from typing import Dict, List, Optional, Protocol, Type, Union

import torch

from src.explainers.base import Explainer
from src.explainers.captum.similarity import CaptumSimilarityExplainer


class ExplainFunc(Protocol):
    def __call__(
        self,
        model: torch.nn.Module,
        model_id: str,
        cache_dir: Optional[str],
        test_tensor: torch.Tensor,
        explanation_targets: Optional[Union[List[int], torch.Tensor]],
        train_dataset: torch.utils.data.Dataset,
        explain_kwargs: Dict,
        init_kwargs: Dict,
        device: Union[str, torch.device],
    ) -> torch.Tensor:
        ...


def explainer_functional_interface(
    explainer_cls: Type[Explainer],
    model: torch.nn.Module,
    model_id: str,
    cache_dir: Optional[str],
    test_tensor: torch.Tensor,
    explanation_targets: Optional[Union[List[int], torch.Tensor]],
    train_dataset: torch.utils.data.Dataset,
    device: Union[str, torch.device],
    init_kwargs: Optional[Dict] = None,
    explain_kwargs: Optional[Dict] = None,
) -> torch.Tensor:
    # Avoid shared mutable default arguments.
    init_kwargs = init_kwargs or {}
    explain_kwargs = explain_kwargs or {}
    explainer = explainer_cls(
        model=model,
        model_id=model_id,
        cache_dir=cache_dir,
        train_dataset=train_dataset,
        device=device,
        **init_kwargs,
    )
    # Forward the explanation targets together with the test tensor so they are not silently dropped.
    return explainer.explain(test_tensor, explanation_targets, **explain_kwargs)


def explainer_self_influence_interface(
    explainer_cls: Type[Explainer],
    model: torch.nn.Module,
    model_id: str,
    cache_dir: Optional[str],
    train_dataset: torch.utils.data.Dataset,
    init_kwargs: Dict,
    device: Union[str, torch.device],
) -> torch.Tensor:
    explainer = explainer_cls(
        model=model,
        model_id=model_id,
        cache_dir=cache_dir,
        train_dataset=train_dataset,
        device=device,
        **init_kwargs,
    )
    return explainer.self_influence_ranking()


def captum_similarity_explain(
    model: torch.nn.Module,
    model_id: str,
    cache_dir: Optional[str],
    test_tensor: torch.Tensor,
    explanation_targets: Optional[Union[List[int], torch.Tensor]],
    train_dataset: torch.utils.data.Dataset,
    device: Union[str, torch.device],
    init_kwargs: Optional[Dict] = None,
    explain_kwargs: Optional[Dict] = None,
) -> torch.Tensor:
    return explainer_functional_interface(
        explainer_cls=CaptumSimilarityExplainer,
        model=model,
        model_id=model_id,
        cache_dir=cache_dir,
        test_tensor=test_tensor,
        explanation_targets=explanation_targets,
        train_dataset=train_dataset,
        device=device,
        init_kwargs=init_kwargs,
        explain_kwargs=explain_kwargs,
    )


def captum_similarity_self_influence_ranking(
    model: torch.nn.Module,
    model_id: str,
    cache_dir: Optional[str],
    train_dataset: torch.utils.data.Dataset,
    init_kwargs: Dict,
    device: Union[str, torch.device],
) -> torch.Tensor:
    return explainer_self_influence_interface(
        explainer_cls=CaptumSimilarityExplainer,
        model=model,
        model_id=model_id,
        cache_dir=cache_dir,
        train_dataset=train_dataset,
        device=device,
        init_kwargs=init_kwargs,
    )
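
A minimal usage sketch for the self-influence wrapper above (not part of this commit). The toy model, random dataset, and layer name are illustrative assumptions; the layer passed via init_kwargs must name a submodule of the actual model.

# Illustrative sketch only: a toy model and random data stand in for real fixtures.
import torch

from src.explainers.functional import captum_similarity_self_influence_ranking

toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
toy_train_dataset = torch.utils.data.TensorDataset(
    torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))
)

ranking = captum_similarity_self_influence_ranking(
    model=toy_model,
    model_id="toy_model",
    cache_dir="./cache/toy_model",
    train_dataset=toy_train_dataset,
    init_kwargs={"layer": "1"},  # assumed: a submodule name; "1" is the Linear layer of the toy model
    device="cpu",
)
print(ranking.shape)  # expected: a ranking over the 32 training samples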
77 changes: 77 additions & 0 deletions tests/explainers/test_explainers.py
@@ -0,0 +1,77 @@
import os

import pytest
import torch

from src.explainers.captum.similarity import CaptumSimilarityExplainer
from src.explainers.functional import captum_similarity_explain
from src.utils.functions.similarities import cosine_similarity


@pytest.mark.explainers
@pytest.mark.parametrize(
    "test_id, model, dataset, test_tensor, test_labels, method, method_kwargs, explanations",
    [
        (
            "mnist",
            "load_mnist_model",
            "load_mnist_dataset",
            "load_mnist_test_samples_1",
            "load_mnist_test_labels_1",
            "SimilarityInfluence",
            {"layer": "relu_4"},
            "load_mnist_explanations_1",
        ),
    ],
)
def test_explain_functional(
    test_id, model, dataset, test_tensor, test_labels, method, method_kwargs, explanations, request
):
    model = request.getfixturevalue(model)
    dataset = request.getfixturevalue(dataset)
    test_tensor = request.getfixturevalue(test_tensor)
    test_labels = request.getfixturevalue(test_labels)
    explanations_exp = request.getfixturevalue(explanations)
    explanations = captum_similarity_explain(
        model,
        "test_id",
        os.path.join("./cache", "test_id"),
        test_tensor,
        test_labels,
        dataset,
        device="cpu",
        init_kwargs=method_kwargs,
    )
    assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"


@pytest.mark.explainers
@pytest.mark.parametrize(
    "test_id, model, dataset, test_tensor, test_labels, method, method_kwargs, explanations",
    [
        (
            "mnist",
            "load_mnist_model",
            "load_mnist_dataset",
            "load_mnist_test_samples_1",
            "load_mnist_test_labels_1",
            "SimilarityInfluence",
            {"layer": "relu_4", "similarity_metric": cosine_similarity},
            "load_mnist_explanations_1",
        ),
    ],
)
def test_explain_stateful(
    test_id, model, dataset, test_tensor, test_labels, method, method_kwargs, explanations, request
):
    model = request.getfixturevalue(model)
    dataset = request.getfixturevalue(dataset)
    test_tensor = request.getfixturevalue(test_tensor)
    test_labels = request.getfixturevalue(test_labels)
    explanations_exp = request.getfixturevalue(explanations)
    explainer = CaptumSimilarityExplainer(
        model=model,
        model_id="test_id",
        cache_dir=os.path.join("./cache", "test_id"),
        train_dataset=dataset,
        device="cpu",
        **method_kwargs,
    )
    explanations = explainer.explain(test_tensor, test_labels)
    assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"
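
Both tests carry the custom "explainers" marker, so, assuming that marker is registered in the project's pytest configuration, this file can be run on its own with: pytest -m explainers tests/explainers/test_explainers.py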
