From 8ee589ff9b1bfe688e98a1ae59321d57d0f0d72c Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Mon, 4 Dec 2023 11:23:19 +0100 Subject: [PATCH 1/3] updated smprt and added tests --- quantus/metrics/randomisation/smooth_mprt.py | 12 +++++---- tests/metrics/test_randomisation_metrics.py | 28 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index cb62a6aa..96e76d1e 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -53,8 +53,8 @@ class SmoothMPRT(Metric): """ Implementation of the Smooth MPRT by Hedström et al., 2023. - The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT, - where the explanations are averaged over N noisy samples before the similarity between the original- + The Smooth Model Parameter Randomisation adds a "denoising" preprocessing step to the original MPRT, + where the explanations are averaged over N noisy samples before the similarity between the original- and fully random model's explanations is measured. References: @@ -350,7 +350,7 @@ def __call__( model.get_model(), x_full_dataset, y_full_dataset, - **kwargs, + **self.explain_func_kwargs, ) # Compute the similarity of explanations of the original model. @@ -384,14 +384,16 @@ def __call__( random_layer_model, x_full_dataset, y_full_dataset, - **kwargs, + **self.explain_func_kwargs, ) # Compute the similarity of explanations of the perturbed model. for a_batch, a_batch_perturbed in zip( self.generate_a_batches(a_full_dataset), a_perturbed_generator ): - for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed): + for a_instance, a_instance_perturbed in zip( + a_batch, a_batch_perturbed + ): score = self.evaluate_instance( model=random_layer_model, x=None, diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py index 95dd75e6..842635ae 100644 --- a/tests/metrics/test_randomisation_metrics.py +++ b/tests/metrics/test_randomisation_metrics.py @@ -3,6 +3,7 @@ import pytest from pytest_lazyfixture import lazy_fixture import numpy as np +from zennit import attribution as zattr from quantus.functions.explanation_func import explain from quantus.functions import complexity_func, n_bins_func @@ -745,6 +746,33 @@ def test_model_parameter_randomisation( }, {"exception": ValueError}, ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "shape": (8, 1, 28, 28), + "canonizer": None, + "composite": None, + "attributor": zattr.Gradient, + "xai_lib": "zennit", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), ], ) def test_smooth_model_parameter_randomisation( From 65c01301fd46f3f7c90426b9bc4d3e70d24f22f0 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Mon, 4 Dec 2023 11:33:20 +0100 Subject: [PATCH 2/3] added one test --- tests/metrics/test_randomisation_metrics.py | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py index 842635ae..91c18e1d 100644 --- a/tests/metrics/test_randomisation_metrics.py +++ b/tests/metrics/test_randomisation_metrics.py @@ -773,6 +773,30 @@ def test_model_parameter_randomisation( }, {"min": -1.0, "max": 1.0}, ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "attributor": zattr.IntegratedGradients, + "xai_lib": "zennit", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), ], ) def test_smooth_model_parameter_randomisation( From c40d1c6964ee7bc63657baf56ddedce70e4b5f13 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Mon, 4 Dec 2023 13:29:20 +0100 Subject: [PATCH 3/3] added explain_func_kwargs to custom_preprocess --- quantus/metrics/randomisation/smooth_mprt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index 96e76d1e..d79317a2 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -546,7 +546,9 @@ def custom_preprocess( return None a_batch_chunks = [] - for a_chunk in self.generate_explanations(model, x_batch, y_batch): + for a_chunk in self.generate_explanations( + model, x_batch, y_batch, **{**kwargs, **self.explain_func_kwargs} + ): a_batch_chunks.extend(a_chunk) return dict(a_batch=np.asarray(a_batch_chunks))