Skip to content

Commit

Permalink
Merge pull request #318 from understandable-machine-intelligence-lab/…
Browse files Browse the repository at this point in the history
…tiny-fixes

Added explain_func_kwargs to SmoothMPRT and zennit tests
  • Loading branch information
annahedstroem authored Dec 4, 2023
2 parents ac36e91 + c40d1c6 commit fcd8ded
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
16 changes: 10 additions & 6 deletions quantus/metrics/randomisation/smooth_mprt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class SmoothMPRT(Metric):
"""
Implementation of the Smooth MPRT by Hedström et al., 2023.
The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT,
where the explanations are averaged over N noisy samples before the similarity between the original-
The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT,
where the explanations are averaged over N noisy samples before the similarity between the original-
and fully random model's explanations is measured.
References:
Expand Down Expand Up @@ -350,7 +350,7 @@ def __call__(
model.get_model(),
x_full_dataset,
y_full_dataset,
**kwargs,
**self.explain_func_kwargs,
)

# Compute the similarity of explanations of the original model.
Expand Down Expand Up @@ -384,14 +384,16 @@ def __call__(
random_layer_model,
x_full_dataset,
y_full_dataset,
**kwargs,
**self.explain_func_kwargs,
)

# Compute the similarity of explanations of the perturbed model.
for a_batch, a_batch_perturbed in zip(
self.generate_a_batches(a_full_dataset), a_perturbed_generator
):
for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
for a_instance, a_instance_perturbed in zip(
a_batch, a_batch_perturbed
):
score = self.evaluate_instance(
model=random_layer_model,
x=None,
Expand Down Expand Up @@ -544,7 +546,9 @@ def custom_preprocess(
return None

a_batch_chunks = []
for a_chunk in self.generate_explanations(model, x_batch, y_batch):
for a_chunk in self.generate_explanations(
model, x_batch, y_batch, **{**kwargs, **self.explain_func_kwargs}
):
a_batch_chunks.extend(a_chunk)
return dict(a_batch=np.asarray(a_batch_chunks))

Expand Down
52 changes: 52 additions & 0 deletions tests/metrics/test_randomisation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pytest_lazyfixture import lazy_fixture
import numpy as np
from zennit import attribution as zattr

from quantus.functions.explanation_func import explain
from quantus.functions import complexity_func, n_bins_func
Expand Down Expand Up @@ -745,6 +746,57 @@ def test_model_parameter_randomisation(
},
{"exception": ValueError},
),
(
lazy_fixture("load_mnist_model"),
lazy_fixture("load_mnist_images"),
{
"init": {
"layer_order": "independent",
"similarity_func": correlation_spearman,
"normalise": True,
"abs": True,
"disable_warnings": True,
"return_average_correlation": False,
"return_last_correlation": True,
"skip_layers": True,
},
"call": {
"explain_func": explain,
"explain_func_kwargs": {
"shape": (8, 1, 28, 28),
"canonizer": None,
"composite": None,
"attributor": zattr.Gradient,
"xai_lib": "zennit",
},
},
},
{"min": -1.0, "max": 1.0},
),
(
lazy_fixture("load_mnist_model"),
lazy_fixture("load_mnist_images"),
{
"init": {
"layer_order": "independent",
"similarity_func": correlation_spearman,
"normalise": True,
"abs": True,
"disable_warnings": True,
"return_average_correlation": False,
"return_last_correlation": True,
"skip_layers": True,
},
"call": {
"explain_func": explain,
"explain_func_kwargs": {
"attributor": zattr.IntegratedGradients,
"xai_lib": "zennit",
},
},
},
{"min": -1.0, "max": 1.0},
),
],
)
def test_smooth_model_parameter_randomisation(
Expand Down

0 comments on commit fcd8ded

Please sign in to comment.