Skip to content

Commit

Permalink
fix and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gumityolcu committed Jun 18, 2024
1 parent 09c73e7 commit c5ddc04
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 49 deletions.
41 changes: 31 additions & 10 deletions tests/explainers/test_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch
from torch.utils.data import TensorDataset

from src.explainers.self_influence import get_self_influence_ranking
from src.utils.explain_wrapper import explain
from src.explainers.captum.similarity import CaptumSimilarityExplainer
from src.explainers.functional import captum_similarity_self_influence_ranking
from src.utils.functions.similarities import dot_product_similarity


Expand All @@ -15,22 +15,43 @@
[
(
"random_data",
{"method": "SimilarityInfluence", "layer": "identity", "similarity_metric": dot_product_similarity},
{"layer": "identity", "similarity_metric": dot_product_similarity},
),
],
)
def test_self_influence_ranking(test_id, explain_kwargs, request):
def test_self_influence(test_id, init_kwargs, request):
model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())]))
# X=torch.randn(1,200)
import os
import shutil

os.mkdir("temp_captum")
os.mkdir("temp_captum2")
torch.random.manual_seed(42)
X = torch.randn(100, 200)
rand_dataset = TensorDataset(X, torch.randint(0, 10, (100,)))
# rand_dataset = TensorDataset(X,torch.randint(0,10,(1,)))
y = torch.randint(0, 10, (100,))
rand_dataset = TensorDataset(X, y)
init_kwargs = {"layer": "identity", "similarity_metric": dot_product_similarity}

self_influence_rank = get_self_influence_ranking(
self_influence_rank_functional = captum_similarity_self_influence_ranking(
model=model,
model_id="0",
cache_dir="temp_captum",
training_data=rand_dataset,
explain_fn=explain,
explain_fn_kwargs=explain_kwargs,
train_dataset=rand_dataset,
init_kwargs=init_kwargs,
device="cpu",
)

explainer_obj = CaptumSimilarityExplainer(
model=model, model_id="1", cache_dir="temp_captum2", train_dataset=rand_dataset, device="cpu", **init_kwargs
)
self_influence_rank_stateful = explainer_obj.self_influence_ranking()

if os.path.isdir("temp_captum2"):
shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2"))
if os.path.isdir("temp_captum"):
shutil.rmtree(os.path.join(os.getcwd(), "temp_captum"))

assert torch.allclose(self_influence_rank, torch.linalg.norm(X, dim=-1).argsort())
assert torch.allclose(self_influence_rank_functional, torch.linalg.norm(X, dim=-1).argsort())
assert torch.allclose(self_influence_rank_functional, self_influence_rank_stateful)
2 changes: 1 addition & 1 deletion tests/metrics/test_unnamed_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@pytest.mark.parametrize(
"test_id, model, dataset, top_k, batch_size, explanations, expected_score",
[
("mnist", "load_mnist_model", "load_mnist_dataset", 3, 8, "load_mnist_explanations_1", 8),
("mnist", "load_mnist_model", "load_mnist_dataset", 3, 8, "load_mnist_explanations_1", 7),
],
)
def test_top_k_overlap_metrics(test_id, model, dataset, top_k, batch_size, explanations, expected_score, request):
Expand Down
38 changes: 0 additions & 38 deletions tests/utils/test_explain_wrapper.py

This file was deleted.

0 comments on commit c5ddc04

Please sign in to comment.