Skip to content

Commit

Permalink
aggregator and self influence tests
Browse files Browse the repository at this point in the history
gumityolcu committed Jun 14, 2024
1 parent 7254e17 commit 75fab05
Showing 2 changed files with 85 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/explainers/aggregators/test_aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import torch

from src.explainers.aggregators.aggregators import (
AbsSumAggregator,
SumAggregator,
)


@pytest.mark.aggregators
@pytest.mark.parametrize(
"test_id, dataset, explanations",
[
(
"mnist",
"load_mnist_dataset",
"load_mnist_explanations_1",
),
],
)
def test_sum_aggregator(test_id, dataset, explanations, request):
dataset = request.getfixturevalue(dataset)
explanations = request.getfixturevalue(explanations)
aggregator = SumAggregator(training_size=len(dataset))
aggregator.update(explanations)
global_rank = aggregator.compute()
assert torch.allclose(global_rank, explanations.sum(dim=0).argsort())


@pytest.mark.aggregators
@pytest.mark.parametrize(
"test_id, dataset, explanations",
[
(
"mnist",
"load_mnist_dataset",
"load_mnist_explanations_1",
),
],
)
def test_abs_aggregator(test_id, dataset, explanations, request):
dataset = request.getfixturevalue(dataset)
explanations = request.getfixturevalue(explanations)
aggregator = AbsSumAggregator(training_size=len(dataset))
aggregator.update(explanations)
global_rank = aggregator.compute()
assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort())
38 changes: 38 additions & 0 deletions tests/explainers/aggregators/test_self_influence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections import OrderedDict

import pytest
import torch
from torch.utils.data import TensorDataset

from src.explainers.aggregators.self_influence import (
get_self_influence_ranking,
)
from src.utils.explain_wrapper import explain
from src.utils.functions.similarities import dot_product_similarity


@pytest.mark.self_influence
@pytest.mark.parametrize(
"test_id, explain_kwargs",
[
(
"random_data",
{"method": "SimilarityInfluence", "layer": "identity", "similarity_metric": dot_product_similarity},
),
],
)
def test_self_influence_ranking(test_id, explain_kwargs, request):
model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())]))
X = torch.randn(100, 200)
rand_dataset = TensorDataset(X, torch.randint(0, 10, (100,)))

self_influence_rank = get_self_influence_ranking(
model=model,
model_id="0",
cache_dir="temp_captum",
training_data=rand_dataset,
explain_fn=explain,
explain_fn_kwargs=explain_kwargs,
)

assert torch.allclose(self_influence_rank, torch.linalg.norm(X, dim=-1).argsort())

0 comments on commit 75fab05

Please sign in to comment.