diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index cb62a6aa..d79317a2 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -53,8 +53,8 @@ class SmoothMPRT(Metric): """ Implementation of the Smooth MPRT by Hedström et al., 2023. - The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT, - where the explanations are averaged over N noisy samples before the similarity between the original- + The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT, + where the explanations are averaged over N noisy samples before the similarity between the original- and fully random model's explanations is measured. References: @@ -350,7 +350,7 @@ def __call__( model.get_model(), x_full_dataset, y_full_dataset, - **kwargs, + **self.explain_func_kwargs, ) # Compute the similarity of explanations of the original model. @@ -384,14 +384,16 @@ def __call__( random_layer_model, x_full_dataset, y_full_dataset, - **kwargs, + **self.explain_func_kwargs, ) # Compute the similarity of explanations of the perturbed model. for a_batch, a_batch_perturbed in zip( self.generate_a_batches(a_full_dataset), a_perturbed_generator ): - for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed): + for a_instance, a_instance_perturbed in zip( + a_batch, a_batch_perturbed + ): score = self.evaluate_instance( model=random_layer_model, x=None, @@ -544,7 +546,9 @@ def custom_preprocess( return None a_batch_chunks = [] - for a_chunk in self.generate_explanations(model, x_batch, y_batch): + for a_chunk in self.generate_explanations( + model, x_batch, y_batch, **{**kwargs, **self.explain_func_kwargs} + ): a_batch_chunks.extend(a_chunk) return dict(a_batch=np.asarray(a_batch_chunks)) diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py index 95dd75e6..91c18e1d 100644 --- a/tests/metrics/test_randomisation_metrics.py +++ b/tests/metrics/test_randomisation_metrics.py @@ -3,6 +3,7 @@ import pytest from pytest_lazyfixture import lazy_fixture import numpy as np +from zennit import attribution as zattr from quantus.functions.explanation_func import explain from quantus.functions import complexity_func, n_bins_func @@ -745,6 +746,57 @@ def test_model_parameter_randomisation( }, {"exception": ValueError}, ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "shape": (8, 1, 28, 28), + "canonizer": None, + "composite": None, + "attributor": zattr.Gradient, + "xai_lib": "zennit", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "attributor": zattr.IntegratedGradients, + "xai_lib": "zennit", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), ], ) def test_smooth_model_parameter_randomisation(