Skip to content

Commit

Permalink
updated smprt and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
annahedstroem committed Dec 4, 2023
1 parent ac36e91 commit 8ee589f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
12 changes: 7 additions & 5 deletions quantus/metrics/randomisation/smooth_mprt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class SmoothMPRT(Metric):
"""
Implementation of the Smooth MPRT by Hedström et al., 2023.
The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT,
where the explanations are averaged over N noisy samples before the similarity between the original-
The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT,
where the explanations are averaged over N noisy samples before the similarity between the original-
and fully random model's explanations is measured.
References:
Expand Down Expand Up @@ -350,7 +350,7 @@ def __call__(
model.get_model(),
x_full_dataset,
y_full_dataset,
**kwargs,
**self.explain_func_kwargs,
)

# Compute the similarity of explanations of the original model.
Expand Down Expand Up @@ -384,14 +384,16 @@ def __call__(
random_layer_model,
x_full_dataset,
y_full_dataset,
**kwargs,
**self.explain_func_kwargs,
)

# Compute the similarity of explanations of the perturbed model.
for a_batch, a_batch_perturbed in zip(
self.generate_a_batches(a_full_dataset), a_perturbed_generator
):
for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
for a_instance, a_instance_perturbed in zip(
a_batch, a_batch_perturbed
):
score = self.evaluate_instance(
model=random_layer_model,
x=None,
Expand Down
28 changes: 28 additions & 0 deletions tests/metrics/test_randomisation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pytest_lazyfixture import lazy_fixture
import numpy as np
from zennit import attribution as zattr

from quantus.functions.explanation_func import explain
from quantus.functions import complexity_func, n_bins_func
Expand Down Expand Up @@ -745,6 +746,33 @@ def test_model_parameter_randomisation(
},
{"exception": ValueError},
),
(
lazy_fixture("load_mnist_model"),
lazy_fixture("load_mnist_images"),
{
"init": {
"layer_order": "independent",
"similarity_func": correlation_spearman,
"normalise": True,
"abs": True,
"disable_warnings": True,
"return_average_correlation": False,
"return_last_correlation": True,
"skip_layers": True,
},
"call": {
"explain_func": explain,
"explain_func_kwargs": {
"shape": (8, 1, 28, 28),
"canonizer": None,
"composite": None,
"attributor": zattr.Gradient,
"xai_lib": "zennit",
},
},
},
{"min": -1.0, "max": 1.0},
),
],
)
def test_smooth_model_parameter_randomisation(
Expand Down

0 comments on commit 8ee589f

Please sign in to comment.