From 171bfbe0f73aa7ca6ae53dff1bad37ae2ff1f408 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Mon, 20 Nov 2023 18:06:41 +0100 Subject: [PATCH 01/11] updates to MPRT test --- quantus/helpers/constants.py | 3 +- quantus/metrics/randomisation/__init__.py | 3 + ...efficient_model_parameter_randomisation.py | 413 ++++++++++++++ .../model_parameter_randomisation.py | 199 +++++-- quantus/metrics/randomisation/mprt.py | 504 ++++++++++++++++++ tests/metrics/test_randomisation_metrics.py | 96 +++- 6 files changed, 1159 insertions(+), 59 deletions(-) create mode 100644 quantus/metrics/randomisation/efficient_model_parameter_randomisation.py create mode 100644 quantus/metrics/randomisation/mprt.py diff --git a/quantus/helpers/constants.py b/quantus/helpers/constants.py index 46fe0edf2..a785868ca 100644 --- a/quantus/helpers/constants.py +++ b/quantus/helpers/constants.py @@ -61,7 +61,8 @@ "Effective Complexity": EffectiveComplexity, }, "Randomisation": { - "Model Parameter Randomisation": ModelParameterRandomisation, + "Model Parameter Randomisation Test": ModelParameterRandomisation, + "Efficient Model Parameter Randomisation Test": EfficientModelParameterRandomisation, "Random Logit": RandomLogit, }, "Axiomatic": { diff --git a/quantus/metrics/randomisation/__init__.py b/quantus/metrics/randomisation/__init__.py index c73f230c8..690262860 100644 --- a/quantus/metrics/randomisation/__init__.py +++ b/quantus/metrics/randomisation/__init__.py @@ -7,4 +7,7 @@ from quantus.metrics.randomisation.model_parameter_randomisation import ( ModelParameterRandomisation, ) +from quantus.metrics.randomisation.efficient_model_parameter_randomisation import ( + EfficientModelParameterRandomisation, +) from quantus.metrics.randomisation.random_logit import RandomLogit diff --git a/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py b/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py new file mode 100644 index 000000000..b78ff1ea6 --- /dev/null +++ b/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py @@ -0,0 +1,413 @@ +"""This module contains the implementation of the Model Parameter Sensitivity metric.""" + +# This file is part of Quantus. +# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. +# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. +# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . +# Quantus project URL: . 
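For orientation (not part of the patch itself): the constants.py hunk above registers both test variants in the metrics catalogue, so after this change they can be looked up like any other Quantus metric. A minimal sketch, assuming this branch is installed and that `AVAILABLE_METRICS` remains exposed at package level as in current releases:

```python
import quantus

# Keys follow the "Randomisation" entries added in the constants.py hunk above.
emprt_cls = quantus.AVAILABLE_METRICS["Randomisation"][
    "Efficient Model Parameter Randomisation Test"
]

# Instantiate like any other metric; argument names follow the __init__ below.
metric = emprt_cls(return_sample_correlation=True, disable_warnings=True)
```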
+
+import sys
+from typing import (
+    Any,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Union,
+    Generator,
+)
+
+
+import numpy as np
+from tqdm.auto import tqdm
+from sklearn.utils import gen_batches
+
+from quantus.functions.similarity_func import correlation_spearman
+from quantus.helpers import asserts, warn
+from quantus.helpers.enums import (
+    DataType,
+    EvaluationCategory,
+    ModelType,
+    ScoreDirection,
+)
+from quantus.helpers.model.model_interface import ModelInterface
+from quantus.metrics.base import Metric
+
+if sys.version_info >= (3, 8):
+    from typing import final
+else:
+    from typing_extensions import final
+
+
+@final
+class EfficientModelParameterRandomisation(Metric):
+    """
+    Implementation of the Efficient Model Parameter Randomisation Method by Hedström et al., 2023.
+
+    The Efficient Model Parameter Randomisation Test replaces MPRT's layer-by-layer pairwise comparison
+    between the original explanation e and the post-randomisation explanation ê by instead computing the
+    relative rise in explanation complexity, using only two model states: the original and the fully
+    randomised model version.
+
+    References:
+        1) Hedström, Anna, et al. "Sanity Checks Revisited: An Exploration to Repair the Model Parameter
+        Randomisation Test." XAI in Action: Past, Present, and Future Applications. 2023.
+
+    Attributes:
+        - _name: The name of the metric.
+        - _data_applicability: The data types that the metric implementation currently supports.
+        - _models: The model types that this metric can work with.
+        - score_direction: How to interpret the scores, whether higher/ lower values are considered better.
+        - evaluation_category: What property/ explanation quality that this metric measures.
+    """
+
+    name = "Efficient Model Parameter Randomisation"
+    data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
+    model_applicability = {ModelType.TORCH, ModelType.TF}
+    score_direction = ScoreDirection.LOWER
+    evaluation_category = EvaluationCategory.RANDOMISATION
+
+    def __init__(
+        self,
+        similarity_func: Optional[Callable] = None,
+        layer_order: str = "independent",
+        seed: int = 42,
+        return_sample_correlation: bool = False,
+        abs: bool = True,
+        normalise: bool = True,
+        normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+        normalise_func_kwargs: Optional[Dict[str, Any]] = None,
+        return_aggregate: bool = False,
+        aggregate_func: Optional[Callable] = None,
+        default_plot_func: Optional[Callable] = None,
+        disable_warnings: bool = False,
+        display_progressbar: bool = False,
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        similarity_func: callable
+            Similarity function applied to compare input and perturbed input, default=correlation_spearman.
+        layer_order: string
+            Indicates whether the model is randomised cascadingly or independently.
+            Set order=top_down for cascading randomisation, set order=independent for independent randomisation,
+            default="independent".
+        seed: integer
+            Seed used for the random generator, default=42.
+        return_sample_correlation: boolean
+            Indicates whether to return one float per sample, representing the average
+            correlation coefficient across the layers for that sample.
+        abs: boolean
+            Indicates whether absolute operation is applied on the attribution, default=True.
+        normalise: boolean
+            Indicates whether normalise operation is applied on the attribution, default=True.
+        normalise_func: callable
+            Attribution normalisation function applied in case normalise=True.
+            If normalise_func=None, the default value is used, default=normalise_by_max.
+ normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. + """ + + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + # Save metric-specific attributes. + if similarity_func is None: + similarity_func = correlation_spearman + self.similarity_func = similarity_func + self.layer_order = layer_order + self.seed = seed + self.return_sample_correlation = return_sample_correlation + + # Results are returned/saved as a dictionary not like in the super-class as a list. + self.evaluation_scores = {} + + # Asserts and warnings. + asserts.assert_layer_order(layer_order=self.layer_order) + if not self.disable_warnings: + warn.warn_parameterisation( + metric_name=self.__class__.__name__, + sensitive_params=( + "similarity metric 'similarity_func' and the order of " + "the layer randomisation 'layer_order'" + ), + citation=( + "Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I., Hardt, M., and Kim, B. " + "'Sanity Checks for Saliency Maps.' arXiv preprint," + " arXiv:1810.073292v3 (2018)" + ), + ) + + def __call__( + self, + model, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: Optional[np.ndarray] = None, + s_batch: Optional[np.ndarray] = None, + channel_first: Optional[bool] = None, + explain_func: Optional[Callable] = None, + explain_func_kwargs: Optional[Dict] = None, + model_predict_kwargs: Optional[Dict] = None, + softmax: Optional[bool] = False, + device: Optional[str] = None, + batch_size: int = 64, + **kwargs, + ) -> Union[List[float], float, Dict[str, List[float]], Collection[Any]]: + """ + This implementation represents the main logic of the metric and makes the class object callable. + It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + output labels (y_batch) and a torch or tensorflow model (model). + + Calls general_preprocess() with all relevant arguments, calls + () on each instance, and saves results to evaluation_scores. + Calls custom_postprocess() afterwards. Finally returns evaluation_scores. + + The content of evaluation_scores will be appended to all_evaluation_scores (list) at the end of + the evaluation call. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. 
+        channel_first: boolean, optional
+            Indicates if the image dimensions are channel first, or channel last.
+            Inferred from the input shape if None.
+        explain_func: callable
+            Callable generating attributions.
+        explain_func_kwargs: dict, optional
+            Keyword arguments to be passed to explain_func on call.
+        model_predict_kwargs: dict, optional
+            Keyword arguments to be passed to the model's predict method.
+        softmax: boolean
+            Indicates whether to use softmax probabilities or logits in model prediction.
+            This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used.
+        device: string
+            Indicates the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu".
+        kwargs: optional
+            Keyword arguments.
+
+        Returns
+        -------
+        evaluation_scores: list
+            A list of Any with the evaluation scores of the concerned batch.
+
+        Examples:
+        --------
+            # Minimal imports.
+            >> import quantus
+            >> from quantus import LeNet
+            >> import torch
+            >> import torchvision
+            >> from captum.attr import Saliency
+
+            # Enable GPU.
+            >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+            # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
+            >> model = LeNet()
+            >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
+
+            # Load MNIST datasets and make loaders.
+            >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
+            >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
+
+            # Load a batch of inputs and outputs to use for XAI evaluation.
+            >> x_batch, y_batch = next(iter(test_loader))
+            >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
+
+            # Generate Saliency attributions of the test set batch.
+            >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
+            >> a_batch_saliency = a_batch_saliency.cpu().numpy()
+
+            # Initialise the metric and evaluate explanations by calling the metric instance.
+            >> metric = EfficientModelParameterRandomisation(abs=True, normalise=False)
+            >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
+        """
+
+        # Run deprecation warnings.
+        warn.deprecation_warnings(kwargs)
+        warn.check_kwargs(kwargs)
+        self.batch_size = batch_size
+        data = self.general_preprocess(
+            model=model,
+            x_batch=x_batch,
+            y_batch=y_batch,
+            a_batch=a_batch,
+            s_batch=s_batch,
+            custom_batch=None,
+            channel_first=channel_first,
+            explain_func=explain_func,
+            explain_func_kwargs=explain_func_kwargs,
+            model_predict_kwargs=model_predict_kwargs,
+            softmax=softmax,
+            device=device,
+        )
+        model: ModelInterface = data["model"]  # type: ignore
+        # Here _batch refers to full dataset.
+        x_full_dataset = data["x_batch"]
+        y_full_dataset = data["y_batch"]
+        a_full_dataset = data["a_batch"]
+        # Results are returned/saved as a dictionary not as a list as in the super-class.
+        self.evaluation_scores = {}
+
+        # Get number of iterations from number of layers.
+        n_layers = model.random_layer_generator_length
+        pbar = tqdm(
+            total=n_layers * len(x_full_dataset), disable=not self.display_progressbar
+        )
+        if self.display_progressbar:
+            # Set property to False, so we display only 1 pbar.
+            self._display_progressbar = False
+
+        def generate_y_batches():
+            for batch in gen_batches(len(a_full_dataset), batch_size):
+                yield a_full_dataset[batch.start : batch.stop]
+
+        with pbar as pbar:
+            for layer_name, random_layer_model in model.get_random_layer_generator(
+                order=self.layer_order, seed=self.seed
+            ):
+                pbar.desc = layer_name
+
+                similarity_scores = []
+                # Generate explanations on modified model in batches
+                a_perturbed_generator = self.generate_explanations(
+                    random_layer_model, x_full_dataset, y_full_dataset, batch_size
+                )
+
+                for a_batch, a_batch_perturbed in zip(
+                    generate_y_batches(), a_perturbed_generator
+                ):
+                    for a_instance, a_instance_perturbed in zip(
+                        a_batch, a_batch_perturbed
+                    ):
+                        result = self.similarity_func(
+                            a_instance_perturbed.flatten(), a_instance.flatten()
+                        )
+                        similarity_scores.append(result)
+                        pbar.update(1)
+                # Save similarity scores in a result dictionary.
+                self.evaluation_scores[layer_name] = similarity_scores
+
+        if self.return_sample_correlation:
+            self.evaluation_scores = self.compute_correlation_per_sample()
+
+        if self.return_aggregate:
+            assert self.return_sample_correlation, (
+                "You must set 'return_sample_correlation'"
+                " to True in order to compute the aggregate."
+            )
+            self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)]
+
+        self.all_evaluation_scores.append(self.evaluation_scores)
+
+        return self.evaluation_scores
+
+    def compute_correlation_per_sample(
+        self,
+    ) -> Union[List[List[Any]], Dict[int, List[Any]]]:
+        assert isinstance(self.evaluation_scores, dict), (
+            "To compute the average correlation coefficient per sample for "
+            "Model Parameter Randomisation Test, 'evaluation_scores' "
+            "must be of type dict."
+        )
+        layer_length = len(
+            self.evaluation_scores[list(self.evaluation_scores.keys())[0]]
+        )
+        results: Dict[int, list] = {sample: [] for sample in range(layer_length)}
+
+        for sample in results:
+            for layer in self.evaluation_scores:
+                results[sample].append(float(self.evaluation_scores[layer][sample]))
+            results[sample] = np.mean(results[sample])
+
+        corr_coeffs = list(results.values())
+
+        return corr_coeffs
+
+    def custom_preprocess(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        a_batch: Optional[np.ndarray],
+        **kwargs,
+    ) -> Optional[Dict[str, np.ndarray]]:
+        """
+        Implementation of custom_preprocess_batch.
+
+        Parameters
+        ----------
+        model: torch.nn.Module, tf.keras.Model
+            A torch or tensorflow model e.g., torchvision.models that is subject to explanation.
+        x_batch: np.ndarray
+            A np.ndarray which contains the input data that are explained.
+        y_batch: np.ndarray
+            A np.ndarray which contains the output labels that are explained.
+        a_batch: np.ndarray, optional
+            A np.ndarray which contains pre-computed attributions i.e., explanations.
+        kwargs:
+            Unused.
+        Returns
+        -------
+        None
+        """
+        # Additional explain_func assert, as the one in general_preprocess()
+        # won't be executed when a_batch != None.
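+        # (general_preprocess() only performs this check when a_batch is None;
+        # this metric re-computes explanations for every randomised model, so
+        # explain_func is required even when pre-computed attributions are given.)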
+        asserts.assert_explain_func(explain_func=self.explain_func)
+        if a_batch is not None:
+            # Just to silence mypy warnings
+            return None
+
+        a_batch_chunks = []
+        for a_chunk in self.generate_explanations(
+            model, x_batch, y_batch, self.batch_size
+        ):
+            a_batch_chunks.extend(a_chunk)
+        return dict(a_batch=np.asarray(a_batch_chunks))
+
+    def generate_explanations(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        batch_size: int,
+    ) -> Generator[np.ndarray, None, None]:
+        """Iterate over the dataset in batches and generate explanations for the complete dataset."""
+        for i in gen_batches(len(x_batch), batch_size):
+            x = x_batch[i.start : i.stop]
+            y = y_batch[i.start : i.stop]
+            a = self.explain_batch(model, x, y)
+            yield a
+
+    def evaluate_batch(self, *args, **kwargs):
+        raise RuntimeError(
+            "`evaluate_batch` must never be called for `EfficientModelParameterRandomisation`."
+        )
diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py
index 9abb99a61..de78886c4 100644
--- a/quantus/metrics/randomisation/model_parameter_randomisation.py
+++ b/quantus/metrics/randomisation/model_parameter_randomisation.py
@@ -1,4 +1,4 @@
-"""This module contains the implementation of the Model Parameter Sensitivity metric."""
+"""This module contains the implementation of the Model Parameter Randomisation Test metric."""
 
 # This file is part of Quantus.
 # Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
@@ -15,7 +15,9 @@
     List,
     Optional,
     Union,
+    Tuple,
     Generator,
+    Iterable,
 )
 
 
@@ -43,9 +45,9 @@
 @final
 class ModelParameterRandomisation(Metric):
     """
-    Implementation of the Model Parameter Randomization Method by Adebayo et. al., 2018.
+    Implementation of the Model Parameter Randomisation Method by Adebayo et al., 2018.
 
-    The Model Parameter Randomization measures the distance between the original attribution and a newly computed
+    The Model Parameter Randomisation measures the distance between the original attribution and a newly computed
     attribution throughout the process of cascadingly/independently randomizing the model parameters of one layer
     at a time.
 
@@ -64,7 +66,7 @@ class ModelParameterRandomisation(Metric):
     - evaluation_category: What property/ explanation quality that this metric measures.
     """
 
-    name = "Model Parameter Randomisation"
+    name = "Model Parameter Randomisation Test"
     data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
     model_applicability = {ModelType.TORCH, ModelType.TF}
     score_direction = ScoreDirection.LOWER
@@ -73,9 +75,11 @@ class ModelParameterRandomisation(Metric):
     def __init__(
         self,
         similarity_func: Optional[Callable] = None,
-        layer_order: str = "independent",
+        layer_order: str = "top_down",
         seed: int = 42,
-        return_sample_correlation: bool = False,
+        return_average_correlation: bool = False,
+        return_last_correlation: bool = False,
+        skip_layers: bool = False,
         abs: bool = True,
         normalise: bool = True,
         normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
@@ -98,9 +102,15 @@ def __init__(
             default="independent".
         seed: integer
             Seed used for the random generator, default=42.
-        return_sample_correlation: boolean
-            Indicates whether to return one float per sample, representing the average
-            correlation coefficient across the layers for that sample.
+ return_average_correlation: boolean + Indicates whether to return one float per sample, computing the average + correlation coefficient across the layers for a given sample. + return_last_correlation: boolean + Indicates whether to return one float per sample, computing the explanation + correlation coefficient for the full model randomisation (not layer-wise) of a sample. + skip_layers: boolean + Indicates if explanation similarity should be computed only once; between the + original and fully randomised model, instead of in a layer-by-layer basis. abs: boolean Indicates whether absolute operation is applied on the attribution, default=True. normalise: boolean @@ -143,12 +153,28 @@ def __init__( self.similarity_func = similarity_func self.layer_order = layer_order self.seed = seed - self.return_sample_correlation = return_sample_correlation + self.return_average_correlation = return_average_correlation + self.return_last_correlation = return_last_correlation + self.skip_layers = skip_layers # Results are returned/saved as a dictionary not like in the super-class as a list. self.evaluation_scores = {} + # TODO. ... renaming warning. + # default values change. + # Asserts and warnings. + if self.return_average_correlation and self.return_last_correlation: + raise ValueError( + f"Both 'return_average_correlation' and 'return_last_correlation' cannot be set to 'True'. " + f"Set both to 'False' or one of the attributes to 'True'." + ) + if self.return_average_correlation and self.skip_layers: + raise ValueError( + f"Both 'return_average_correlation' and 'skip_layers' cannot be set to 'True'. " + f"You need to calculate the explanation correlation at all layers in order " + f"to compute the average correlation coefficient on all layers." + ) asserts.assert_layer_order(layer_order=self.layer_order) if not self.disable_warnings: warn.warn_parameterisation( @@ -280,6 +306,7 @@ def __call__( x_full_dataset = data["x_batch"] y_full_dataset = data["y_batch"] a_full_dataset = data["a_batch"] + # Results are returned/saved as a dictionary not as a list as in the super-class. self.evaluation_scores = {} @@ -297,38 +324,76 @@ def generate_y_batches(): yield a_full_dataset[batch.start : batch.stop] with pbar as pbar: - for layer_name, random_layer_model in model.get_random_layer_generator( - order=self.layer_order, seed=self.seed + for l_ix, (layer_name, random_layer_model) in enumerate( + model.get_random_layer_generator(order=self.layer_order, seed=self.seed) ): pbar.desc = layer_name similarity_scores = [] - # Generate explanations on modified model in batches - a_perturbed_generator = self.generate_explanations( - random_layer_model, x_full_dataset, y_full_dataset, batch_size - ) - - for a_batch, a_batch_perturbed in zip( - generate_y_batches(), a_perturbed_generator - ): - for a_instance, a_instance_perturbed in zip( - a_batch, a_batch_perturbed + + # Skip layers if computing delta. + if self.skip_layers and (l_ix + 1) < n_layers: + continue + + if l_ix == 0: + # Generate explanations on modified model in batches + a_original_generator = self.generate_explanations( + model.get_model(), x_full_dataset, y_full_dataset, batch_size + ) + + for a_batch, a_batch_original in zip( + generate_y_batches(), a_original_generator ): - result = self.similarity_func( - a_instance_perturbed.flatten(), a_instance.flatten() - ) - similarity_scores.append(result) - pbar.update(1) - # Save similarity scores in a result dictionary. 
-                self.evaluation_scores[layer_name] = similarity_scores
+                    for a_instance, a_instance_original in zip(
+                        a_batch, a_batch_original
+                    ):
+                        score = self.evaluate_instance(
+                            model=model,
+                            x=None,
+                            y=None,
+                            s=None,
+                            a=a_instance,
+                            a_perturbed=a_instance_original,
+                        )
+                        similarity_scores.append(score)
+                        pbar.update(1)
+
+                    # Save similarity scores in a result dictionary.
+                    self.evaluation_scores["original"] = similarity_scores
+
+                    # Start a fresh list, so the baseline scores are not
+                    # appended to (and aliased with) the layer scores below.
+                    similarity_scores = []
+
+                # Generate explanations on modified model in batches
+                a_perturbed_generator = self.generate_explanations(
+                    random_layer_model, x_full_dataset, y_full_dataset, batch_size
+                )
 
-        if self.return_sample_correlation:
-            self.evaluation_scores = self.compute_correlation_per_sample()
+                for a_batch, a_batch_perturbed in zip(
+                    generate_y_batches(), a_perturbed_generator
+                ):
+                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
+                        score = self.evaluate_instance(
+                            model=random_layer_model,
+                            x=None,
+                            y=None,
+                            s=None,
+                            a=a_instance,
+                            a_perturbed=a_instance_perturbed,
+                        )
+                        similarity_scores.append(score)
+                        pbar.update(1)
+
+                # Save similarity scores in a result dictionary.
+                self.evaluation_scores[layer_name] = similarity_scores
+
+        if self.return_average_correlation:
+            self.evaluation_scores = self.recompute_average_correlation_per_sample()
+
+        elif self.return_last_correlation:
+            self.evaluation_scores = self.recompute_last_correlation_per_sample()
 
         if self.return_aggregate:
-            assert self.return_sample_correlation, (
-                "You must set 'return_average_correlation_per_sample'"
-                " to True in order to compute te aggregat"
+            assert self.return_average_correlation or self.return_last_correlation, (
+                "Set 'return_average_correlation' or 'return_last_correlation'"
+                " to True in order to compute the aggregate evaluation results."
             )
             self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)]
 
@@ -336,12 +401,13 @@ def generate_y_batches():
 
         return self.evaluation_scores
 
-    def compute_correlation_per_sample(
+    def recompute_average_correlation_per_sample(
         self,
-    ) -> Union[List[List[Any]], Dict[int, List[Any]]]:
+    ) -> List[float]:
+
         assert isinstance(self.evaluation_scores, dict), (
             "To compute the average correlation coefficient per sample for "
-            "Model Parameter Randomisation Test, 'last_result' "
+            "enhanced Model Parameter Randomisation Test, 'evaluation_scores' "
             "must be of type dict."
         )
         layer_length = len(
@@ -351,6 +417,8 @@ def compute_correlation_per_sample(
 
         for sample in results:
             for layer in self.evaluation_scores:
+                if layer == "original":
+                    continue
                 results[sample].append(float(self.evaluation_scores[layer][sample]))
             results[sample] = np.mean(results[sample])
 
@@ -358,6 +426,58 @@ def compute_correlation_per_sample(
 
         return corr_coeffs
 
+    def recompute_last_correlation_per_sample(
+        self,
+    ) -> List[float]:
+
+        assert isinstance(self.evaluation_scores, dict), (
+            "To compute the last correlation coefficient per sample for "
+            "enhanced Model Parameter Randomisation Test, 'evaluation_scores' "
+            "must be of type dict."
+        )
+        # Return the correlation coefficient of the fully randomised model
+        # (excluding the non-randomised correlation).
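+        # Dict insertion order (guaranteed since Python 3.7) preserves the
+        # layer order, so the last value belongs to the final randomised layer.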
+ corr_coeffs = list(self.evaluation_scores.values())[-1] + corr_coeffs = [float(c) for c in corr_coeffs] + return corr_coeffs + + def evaluate_instance( + self, + model: ModelInterface, + x: Optional[np.ndarray], + y: Optional[np.ndarray], + a: Optional[np.ndarray], + s: Optional[np.ndarray], + a_perturbed: Optional[np.ndarray] = None, + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + + Parameters + ---------- + i: integer + The evaluation instance. + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. + a_perturbed: np.ndarray + The perturbed attributions. + + Returns + ------- + float + The evaluation results. + """ + # Compute similarity measure. + return self.similarity_func(a_perturbed.flatten(), a.flatten()) + def custom_preprocess( self, model: ModelInterface, @@ -388,8 +508,7 @@ def custom_preprocess( # Additional explain_func assert, as the one in general_preprocess() # won't be executed when a_batch != None. asserts.assert_explain_func(explain_func=self.explain_func) - if a_batch is not None: - # Just to silence mypy warnings + if a_batch is not None: # Just to silence mypy warnings return None a_batch_chunks = [] @@ -415,5 +534,5 @@ def generate_explanations( def evaluate_batch(self, *args, **kwargs): raise RuntimeError( - "`evaluate_batch` must never be called for `ModelParameterRandomisation`." + "`evaluate_batch` must never be called for `Model Parameter Randomisation`." ) diff --git a/quantus/metrics/randomisation/mprt.py b/quantus/metrics/randomisation/mprt.py new file mode 100644 index 000000000..79310c1c9 --- /dev/null +++ b/quantus/metrics/randomisation/mprt.py @@ -0,0 +1,504 @@ +"""This module contains the implementation of the Model Parameter Randomisation Test metric.""" + +# This file is part of Quantus. +# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. +# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. +# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . +# Quantus project URL: . + +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + Collection, + Iterable, +) +import numpy as np +from tqdm.auto import tqdm + +from quantus.helpers import asserts +from quantus.helpers import warn +from quantus.helpers.model.model_interface import ModelInterface +from quantus.functions.normalise_func import normalise_by_max +from quantus.functions.similarity_func import correlation_spearman +from quantus.metrics.base import Metric +from quantus.helpers.enums import ( + ModelType, + DataType, + ScoreDirection, + EvaluationCategory, +) + + +class MPRT(Metric): + """ + Implementation of the Model Parameter Randomisation Test by Adebayo et. al., 2018. 
+ + The Model Parameter Randomization measures the distance between the original attribution and a newly computed + attribution throughout the process of cascadingly/independently randomizing the model parameters of one layer + at a time. + + Assumptions: + - In the original paper multiple distance measures are taken: Spearman rank correlation (with and without abs), + HOG and SSIM. We have set Spearman as the default value. + + References: + 1) Julius Adebayo et al.: "Sanity Checks for Saliency Maps." NeurIPS (2018): 9525-9536. + + Attributes: + - _name: The name of the metric. + - _data_applicability: The data types that the metric implementation currently supports. + - _models: The model types that this metric can work with. + - score_direction: How to interpret the scores, whether higher/ lower values are considered better. + - evaluation_category: What property/ explanation quality that this metric measures. + """ + + name = "Model Parameter Randomisation Test" + data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} + model_applicability = {ModelType.TORCH, ModelType.TF} + score_direction = ScoreDirection.LOWER + evaluation_category = EvaluationCategory.RANDOMISATION + + def __init__( + self, + similarity_func: Callable = None, + layer_order: str = "independent", + seed: int = 42, + return_sample_correlation: bool = False, + return_last_correlation: bool = False, + skip_layers: bool = False, + abs: bool = True, + normalise: bool = True, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + return_aggregate: bool = False, + aggregate_func: Callable = None, + default_plot_func: Optional[Callable] = None, + disable_warnings: bool = False, + display_progressbar: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + similarity_func: callable + Similarity function applied to compare input and perturbed input, default=correlation_spearman. + layer_order: string + Indicated whether the model is randomized cascadingly or independently. + Set order=top_down for cascading randomization, set order=independent for independent randomization, + default="independent". + seed: integer + Seed used for the random generator, default=42. + return_sample_correlation: boolean + Indicates whether return one float per sample, representing the average + correlation coefficient across the layers for that sample. + abs: boolean + Indicates whether absolute operation is applied on the attribution, default=True. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. 
+ """ + if normalise_func is None: + normalise_func = normalise_by_max + + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + # Save metric-specific attributes. + if similarity_func is None: + similarity_func = correlation_spearman + self.similarity_func = similarity_func + self.layer_order = layer_order + self.seed = seed + self.return_sample_correlation = return_sample_correlation + self.return_last_correlation = return_last_correlation + self.skip_layers = skip_layers + + if self.return_sample_correlation and self.return_last_correlation: + raise ValueError( + f"Both 'return_sample_correlation' and 'return_last_correlation' cannot be True. Pick one." + ) + + # Results are returned/saved as a dictionary not like in the super-class as a list. + self.evaluation_scores = {} + + # Asserts and warnings. + asserts.assert_layer_order(layer_order=self.layer_order) + if not self.disable_warnings: + warn.warn_parameterisation( + metric_name=self.__class__.__name__, + sensitive_params=( + "similarity metric 'similarity_func' and the order of " + "the layer randomisation 'layer_order'" + ), + citation=( + "Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I., Hardt, M., and Kim, B. " + "'Sanity Checks for Saliency Maps.' arXiv preprint," + " arXiv:1810.073292v3 (2018)" + ), + ) + + def __call__( + self, + model, + x_batch: np.array, + y_batch: np.array, + a_batch: Optional[np.ndarray] = None, + s_batch: Optional[np.ndarray] = None, + channel_first: Optional[bool] = None, + explain_func: Optional[Callable] = None, + explain_func_kwargs: Optional[Dict] = None, + model_predict_kwargs: Optional[Dict] = None, + softmax: Optional[bool] = False, + device: Optional[str] = None, + batch_size: int = 64, + custom_batch: Optional[Any] = None, + **kwargs, + ) -> Union[List[float], float, Dict[str, List[float]], Collection[Any]]: + """ + This implementation represents the main logic of the metric and makes the class object callable. + It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + output labels (y_batch) and a torch or tensorflow model (model). + + Calls general_preprocess() with all relevant arguments, calls + () on each instance, and saves results to evaluation_scores. + Calls custom_postprocess() afterwards. Finally returns evaluation_scores. + + The content of evaluation_scores will be appended to all_evaluation_scores (list) at the end of + the evaluation call. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + channel_first: boolean, optional + Indicates of the image dimensions are channel first, or channel last. + Inferred from the input shape if None. + explain_func: callable + Callable generating attributions. 
+ explain_func_kwargs: dict, optional + Keyword arguments to be passed to explain_func on call. + model_predict_kwargs: dict, optional + Keyword arguments to be passed to the model's predict method. + softmax: boolean + Indicates whether to use softmax probabilities or logits in model prediction. + This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. + device: string + Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". + kwargs: optional + Keyword arguments. + + Returns + ------- + evaluation_scores: list + a list of Any with the evaluation scores of the concerned batch. + + Examples: + -------- + # Minimal imports. + >> import quantus + >> from quantus import LeNet + >> import torch + + # Enable GPU. + >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). + >> model = LeNet() + >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) + + # Load MNIST datasets and make loaders. + >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) + >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) + + # Load a batch of inputs and outputs to use for XAI evaluation. + >> x_batch, y_batch = iter(test_loader).next() + >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() + + # Generate Saliency attributions of the test set batch of the test set. + >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) + >> a_batch_saliency = a_batch_saliency.cpu().numpy() + + # Initialise the metric and evaluate explanations by calling the metric instance. + >> metric = Metric(abs=True, normalise=False) + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + """ + + # Run deprecation warnings. + warn.deprecation_warnings(kwargs) + warn.check_kwargs(kwargs) + + data = self.general_preprocess( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + s_batch=s_batch, + custom_batch=None, + channel_first=channel_first, + explain_func=explain_func, + explain_func_kwargs=explain_func_kwargs, + model_predict_kwargs=model_predict_kwargs, + softmax=softmax, + device=device, + ) + model = data["model"] + x_batch = data["x_batch"] + y_batch = data["y_batch"] + a_batch = data["a_batch"] + + # Results are returned/saved as a dictionary not as a list as in the super-class. + self.correlation_scores = np.zeros((len(x_batch))) + self.similarity_scores = {} + self.evaluation_scores = {} + + # Get number of iterations from number of layers. + n_layers = len(list(model.get_random_layer_generator(order=self.layer_order))) + + model_iterator = tqdm( + model.get_random_layer_generator(order=self.layer_order, seed=self.seed), + total=n_layers, + disable=not self.display_progressbar, + ) + + for l_ix, (layer_name, random_layer_model) in enumerate(model_iterator): + + similarity_scores = [None for _ in x_batch] + + # Skip layers if computing delta. + if self.skip_layers and (l_ix + 1) < len(model_iterator): + continue + + # Save correlation scores of no perturbation. + if ( + l_ix == 0 + ): # (l_ix == 0 and self.layer_order == "bottom_up") or (l_ix+1 == len(model_iterator) and self.layer_order == "top_down"): + + # Generate an explanation with original model. 
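+                # This scores the pre-computed a_batch against freshly generated
+                # explanations from the still unrandomised model, giving a
+                # "no randomisation" baseline for the layer-wise scores below.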
+                a_batch_original = self.explain_func(
+                    model=model.get_model(),
+                    inputs=x_batch,
+                    targets=y_batch,
+                    **self.explain_func_kwargs,
+                )
+
+                batch_iterator = enumerate(zip(a_batch, a_batch_original))
+                for instance_id, (a_instance, a_ori) in batch_iterator:
+                    score = self.evaluate_instance(
+                        model=model,
+                        x=None,
+                        y=None,
+                        s=None,
+                        a=a_instance,
+                        a_perturbed=a_ori,
+                    )
+                    similarity_scores[instance_id] = score
+
+                # Save similarity scores in a result dictionary.
+                self.similarity_scores["orig"] = similarity_scores
+
+                # Use a fresh list below, so that the layer scores do not
+                # overwrite the "orig" entry (both would otherwise reference
+                # the same list object).
+                similarity_scores = [None for _ in x_batch]
+
+            # Generate an explanation with perturbed model.
+            a_batch_perturbed = self.explain_func(
+                model=random_layer_model,
+                inputs=x_batch,
+                targets=y_batch,
+                **self.explain_func_kwargs,
+            )
+
+            batch_iterator = enumerate(zip(a_batch, a_batch_perturbed))
+            for instance_id, (a_instance, a_instance_perturbed) in batch_iterator:
+                score = self.evaluate_instance(
+                    model=random_layer_model,
+                    x=None,
+                    y=None,
+                    s=None,
+                    a=a_instance,
+                    a_perturbed=a_instance_perturbed,
+                )
+                similarity_scores[instance_id] = score
+
+            # Save similarity scores in a result dictionary.
+            self.similarity_scores[layer_name] = similarity_scores
+
+        # Call post-processing.
+        self.custom_postprocess(
+            model=model,
+            x_batch=x_batch,
+            y_batch=y_batch,
+            a_batch=a_batch,
+            s_batch=s_batch,
+        )
+
+        if self.return_sample_correlation:
+            self.correlation_scores = self.recompute_correlation_per_sample()
+            self.evaluation_scores = self.correlation_scores
+
+        elif self.return_last_correlation:
+            self.correlation_scores = self.recompute_last_correlation_per_sample()
+            self.evaluation_scores = self.correlation_scores
+
+        if self.return_aggregate:
+            assert self.return_sample_correlation or self.return_last_correlation, (
+                "Set 'return_sample_correlation' or 'return_last_correlation'"
+                " to True in order to compute the aggregate evaluation results."
+            )
+            self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)]
+
+        self.all_evaluation_scores.append(self.evaluation_scores)
+
+        return self.evaluation_scores
+
+    def evaluate_instance(
+        self,
+        model: ModelInterface,
+        x: Optional[np.ndarray],
+        y: Optional[np.ndarray],
+        a: Optional[np.ndarray],
+        s: Optional[np.ndarray],
+        a_perturbed: Optional[np.ndarray] = None,
+    ) -> float:
+        """
+        Evaluate instance gets model and data for a single instance as input and returns the evaluation result.
+
+        Parameters
+        ----------
+        model: ModelInterface
+            A ModelInterface that is subject to explanation.
+        x: np.ndarray
+            The input to be evaluated on an instance-basis.
+        y: np.ndarray
+            The output to be evaluated on an instance-basis.
+        a: np.ndarray
+            The explanation to be evaluated on an instance-basis.
+        s: np.ndarray
+            The segmentation to be evaluated on an instance-basis.
+        a_perturbed: np.ndarray
+            The perturbed attributions.
+
+        Returns
+        -------
+        float
+            The evaluation results.
+        """
+        if self.normalise:
+            a_perturbed = self.normalise_func(a_perturbed, **self.normalise_func_kwargs)
+
+        if self.abs:
+            a_perturbed = np.abs(a_perturbed)
+
+        # Compute distance measure.
+        return self.similarity_func(a_perturbed.flatten(), a.flatten())
+
+    def custom_preprocess(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: Optional[np.ndarray],
+        a_batch: Optional[np.ndarray],
+        s_batch: np.ndarray,
+        custom_batch: Optional[np.ndarray],
+    ) -> None:
+        """
+        Implementation of custom_preprocess_batch.
+
+        Parameters
+        ----------
+        model: torch.nn.Module, tf.keras.Model
+            A torch or tensorflow model e.g., torchvision.models that is subject to explanation.
+        x_batch: np.ndarray
+            A np.ndarray which contains the input data that are explained.
+        y_batch: np.ndarray
+            A np.ndarray which contains the output labels that are explained.
+        a_batch: np.ndarray, optional
+            A np.ndarray which contains pre-computed attributions i.e., explanations.
+        s_batch: np.ndarray, optional
+            A np.ndarray which contains segmentation masks that matches the input.
+        custom_batch: any
+            Gives flexibility to the user to use for evaluation, can hold any variable.
+
+        Returns
+        -------
+        None
+        """
+        # Additional explain_func assert, as the one in general_preprocess()
+        # won't be executed when a_batch != None.
+        asserts.assert_explain_func(explain_func=self.explain_func)
+
+    def recompute_correlation_per_sample(
+        self,
+    ) -> List[float]:
+
+        assert isinstance(self.similarity_scores, dict), (
+            "To compute the average correlation coefficient per sample for "
+            "enhanced Model Parameter Randomisation Test, 'similarity_scores' "
+            "must be of type dict."
+        )
+        layer_length = len(
+            self.similarity_scores[list(self.similarity_scores.keys())[0]]
+        )
+        results: Dict[int, list] = {sample: [] for sample in range(layer_length)}
+
+        for sample in results:
+            for layer in self.similarity_scores:
+                if layer == "orig":
+                    continue
+                results[sample].append(float(self.similarity_scores[layer][sample]))
+            results[sample] = np.mean(results[sample])
+
+        corr_coeffs = list(results.values())
+
+        return corr_coeffs
+
+    def recompute_last_correlation_per_sample(
+        self,
+    ) -> List[float]:
+
+        assert isinstance(self.similarity_scores, dict), (
+            "To compute the last correlation coefficient per sample for "
+            "enhanced Model Parameter Randomisation Test, 'similarity_scores' "
+            "must be of type dict."
+        )
+        # Return the correlation coefficient of the fully randomised model.
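+        # (When skip_layers=True this is also the only recorded entry, since
+        # all earlier layers, including the l_ix == 0 baseline, are skipped.)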
+ corr_coeffs = list(self.similarity_scores.values())[-1] + + return corr_coeffs diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py index 44f1ba9e6..c0f0c9cf0 100644 --- a/tests/metrics/test_randomisation_metrics.py +++ b/tests/metrics/test_randomisation_metrics.py @@ -16,7 +16,7 @@ def explain_func_stub(*args, **kwargs): return np.random.uniform(low=0, high=0.5, size=input_shape) -@pytest.mark.randomisation +@pytest.mark.mprt @pytest.mark.parametrize( "model,data,params,expected", [ @@ -221,7 +221,7 @@ def explain_func_stub(*args, **kwargs): }, }, }, - {"min": -1.0, "max": 1.01}, + {"min": -1.0, "max": 1.0}, ), ( lazy_fixture("titanic_model_tf"), @@ -236,7 +236,61 @@ def explain_func_stub(*args, **kwargs): }, "call": {"explain_func": explain_func_stub}, }, - {"min": -1.0, "max": 1.01}, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "top_down", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": True, + "return_last_correlation": False, + "skip_layers": False, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": False, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, ), ], ) @@ -279,23 +333,29 @@ def test_model_parameter_randomisation( ) return - scores_layers = ModelParameterRandomisation(**init_params)( + scores = ModelParameterRandomisation(**init_params)( model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch, **call_params, ) - if isinstance(expected, float): - assert all( - s == expected for layer, scores in scores_layers.items() for s in scores - ), "Test failed." - else: - assert all( - ((s > expected["min"]) & (s < expected["max"])) - for layer, scores in scores_layers.items() - for s in scores - ), "Test failed." + + if isinstance(scores, dict): + for layer, scores_layer in scores.items(): + out_of_range_scores = [ + s for s in scores_layer if not (expected["min"] <= s <= expected["max"]) + ] + assert ( + not out_of_range_scores + ), f"Test failed for layer {layer}. Out of range scores: {out_of_range_scores}" + elif isinstance(scores, list): + out_of_range_scores = [ + s for s in scores if not (expected["min"] <= s <= expected["max"]) + ] + assert ( + not out_of_range_scores + ), f"Test failed. 
Out of range scores: {out_of_range_scores}" @pytest.mark.randomisation @@ -443,7 +503,7 @@ def test_model_parameter_randomisation( }, }, }, - {"min": -1.0, "max": 1.01}, + {"min": -1.0, "max": 1.0}, ), ( lazy_fixture("titanic_model_tf"), @@ -457,7 +517,7 @@ def test_model_parameter_randomisation( }, "call": {"softmax": True, "explain_func": explain_func_stub}, }, - {"min": -1.0, "max": 1.01}, + {"min": -1.0, "max": 1.0}, ), ], ) @@ -499,5 +559,5 @@ def test_random_logit( if isinstance(expected, float): assert all(s == expected for s in scores), "Test failed." else: - assert all(s > expected["min"] for s in scores), "Test failed." - assert all(s < expected["max"] for s in scores), "Test failed." + assert all(s >= expected["min"] for s in scores), "Test failed." + assert all(s <= expected["max"] for s in scores), "Test failed." From 44008e24f7ed5d5ab1a1d963da45121cb78236b3 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Mon, 20 Nov 2023 18:17:56 +0100 Subject: [PATCH 02/11] fixed tests --- tests/metrics/test_randomisation_metrics.py | 55 +++++++++++++++++---- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py index c0f0c9cf0..bb858365a 100644 --- a/tests/metrics/test_randomisation_metrics.py +++ b/tests/metrics/test_randomisation_metrics.py @@ -16,7 +16,7 @@ def explain_func_stub(*args, **kwargs): return np.random.uniform(low=0, high=0.5, size=input_shape) -@pytest.mark.mprt +@pytest.mark.randomisation @pytest.mark.parametrize( "model,data,params,expected", [ @@ -292,6 +292,42 @@ def explain_func_stub(*args, **kwargs): }, {"min": -1.0, "max": 1.0}, ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": True, + "return_last_correlation": True, + "skip_layers": True, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"exception": ValueError}, + ), ], ) def test_model_parameter_randomisation( @@ -503,7 +539,7 @@ def test_model_parameter_randomisation( }, }, }, - {"min": -1.0, "max": 1.0}, + {"min": -1.0, "max": 1.1}, ), ( lazy_fixture("titanic_model_tf"), @@ -514,10 +550,11 @@ def test_model_parameter_randomisation( "normalise": True, "abs": True, "disable_warnings": True, + "similarity_func": correlation_pearson, }, "call": {"softmax": True, "explain_func": explain_func_stub}, }, - {"min": -1.0, "max": 1.0}, + {"min": -1.0, "max": 1.1}, ), ], ) @@ -555,9 +592,9 @@ def test_random_logit( a_batch=a_batch, **call_params, ) - - if isinstance(expected, float): - assert all(s == expected for s in scores), "Test failed." - else: - assert all(s >= expected["min"] for s in scores), "Test failed." - assert all(s <= expected["max"] for s in scores), "Test failed." + for s in scores: + if not (expected["min"] <= s <= expected["max"]): + print("!!!!", s) + assert all( + expected["min"] <= s <= expected["max"] for s in scores + ), f"Test failed with scores {scores}." 
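Between these two commits, it may help to see how the new flags are exercised from the caller's side. A minimal sketch (not part of the patch; `model`, `x_batch` and `y_batch` stand in for the MNIST fixtures used in the tests above):

```python
import numpy as np
from quantus import ModelParameterRandomisation

def explain_func_stub(model, inputs, targets, **kwargs):
    # Random attributions, mirroring the stub at the top of the test file.
    return np.random.uniform(low=0, high=0.5, size=np.shape(inputs))

metric = ModelParameterRandomisation(
    layer_order="bottom_up",
    return_last_correlation=True,
    skip_layers=True,  # compare only against the fully randomised model
    disable_warnings=True,
)
scores = metric(
    model=model,  # assumed: a trained torch/tf model with matching numpy batches
    x_batch=x_batch,
    y_batch=y_batch,
    explain_func=explain_func_stub,
)
# scores: one correlation per sample, each within [-1.0, 1.0].
```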
From 5d1c708a12546ff4fa8988dbd578468ef21daba7 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Mon, 20 Nov 2023 18:18:31 +0100 Subject: [PATCH 03/11] remove FIXME for notifications of updated params --- quantus/metrics/randomisation/model_parameter_randomisation.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index de78886c4..3658b0dbc 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -160,9 +160,6 @@ def __init__( # Results are returned/saved as a dictionary not like in the super-class as a list. self.evaluation_scores = {} - # TODO. ... renaming warning. - # default values change. - # Asserts and warnings. if self.return_average_correlation and self.return_last_correlation: raise ValueError( From 7d0105a84cdf72cf03a990797e1273044e73fb3e Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Mon, 20 Nov 2023 22:14:46 +0100 Subject: [PATCH 04/11] quantus updates: emprt implementation --- quantus/__init__.py | 2 +- quantus/functions/complexity_func.py | 102 ++++ quantus/functions/n_bins_func.py | 69 +++ quantus/functions/normalise_func.py | 2 +- quantus/helpers/constants.py | 10 + quantus/helpers/model/model_interface.py | 2 +- quantus/helpers/model/pytorch_model.py | 2 +- quantus/helpers/model/tf_model.py | 12 +- quantus/helpers/warn.py | 16 +- ...efficient_model_parameter_randomisation.py | 445 +++++++++++++++--- .../randomisation/{mprt.py => emprt.py} | 432 +++++++++++++---- .../model_parameter_randomisation.py | 22 +- tests/metrics/test_randomisation_metrics.py | 354 +++++++++++++- 13 files changed, 1283 insertions(+), 187 deletions(-) create mode 100644 quantus/functions/complexity_func.py create mode 100644 quantus/functions/n_bins_func.py rename quantus/metrics/randomisation/{mprt.py => emprt.py} (53%) diff --git a/quantus/__init__.py b/quantus/__init__.py index c566fa8ea..87053c15e 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -5,7 +5,7 @@ # Quantus project URL: . # Set the correct version. -__version__ = "0.4.5" +__version__ = "0.5.0" # Expose quantus.evaluate to the user. from quantus.evaluation import evaluate diff --git a/quantus/functions/complexity_func.py b/quantus/functions/complexity_func.py new file mode 100644 index 000000000..94c905ed4 --- /dev/null +++ b/quantus/functions/complexity_func.py @@ -0,0 +1,102 @@ +"""This module holds a collection of functions to compute the complexity (of explanations).""" + +# This file is part of Quantus. +# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. +# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. +# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . +# Quantus project URL: . + +import scipy +import numpy as np + + +def entropy(a: np.array, x: np.array, **kwargs) -> float: + """ + Calculate entropy. + + Parameters + ---------- + a: np.ndarray + Array to calculate entropy on. + x: np.ndarray + Array to compute shape. + kwargs: optional + Keyword arguments. 
+
+    Returns
+    -------
+    float:
+        A floating point, ranging [0, inf].
+    """
+    assert (a >= 0).all(), "Entropy computation requires non-negative values."
+
+    if len(x.shape) == 1:
+        newshape = np.prod(x.shape)
+    else:
+        newshape = np.prod(x.shape[1:])
+
+    a_reshaped = np.reshape(a, int(newshape))
+    a_normalised = a_reshaped.astype(np.float64) / np.sum(np.abs(a_reshaped))
+    return scipy.stats.entropy(pk=a_normalised)
+
+
+def gini_coeffiient(a: np.array, x: np.array, **kwargs) -> float:
+    """
+    Calculate Gini coefficient.
+
+    Parameters
+    ----------
+    a: np.ndarray
+        Array to calculate gini_coeffiient on.
+    x: np.ndarray
+        Array to compute shape.
+    kwargs: optional
+        Keyword arguments.
+
+    Returns
+    -------
+    float:
+        A floating point, ranging [0, 1].
+
+    """
+
+    if len(x.shape) == 1:
+        newshape = np.prod(x.shape)
+    else:
+        newshape = np.prod(x.shape[1:])
+
+    a = np.array(np.reshape(a, newshape), dtype=np.float64)
+    a += 0.0000001
+    a = np.sort(a)
+    score = (np.sum((2 * np.arange(1, a.shape[0] + 1) - a.shape[0] - 1) * a)) / (
+        a.shape[0] * np.sum(a)
+    )
+    return score
+
+
+def discrete_entropy(a: np.array, x: np.array, **kwargs) -> float:
+    """
+    Calculate discrete entropy of explanations with n_bins equidistantly spaced bins.
+
+    Parameters
+    ----------
+    a: np.ndarray
+        Array to calculate entropy on.
+    x: np.ndarray
+        Array to compute shape.
+    kwargs: optional
+        Keyword arguments.
+
+        n_bins: int
+            Number of bins, default=100.
+
+    Returns
+    -------
+    float:
+        Discrete Entropy.
+    """
+
+    n_bins = kwargs.get("n_bins", 100)
+
+    histogram, bins = np.histogram(a, bins=n_bins)
+
+    return scipy.stats.entropy(pk=histogram)
diff --git a/quantus/functions/n_bins_func.py b/quantus/functions/n_bins_func.py
new file mode 100644
index 000000000..6679d6790
--- /dev/null
+++ b/quantus/functions/n_bins_func.py
@@ -0,0 +1,69 @@
+"""This module holds a collection of algorithms to calculate a number of bins to use for entropy calculation."""
+
+# This file is part of Quantus.
+# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
+# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
+# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see .
+# Quantus project URL: .
+
+import scipy
+import numpy as np
+
+
+def freedman_diaconis_rule(a_batch: np.array) -> int:
+    """Freedman–Diaconis' rule."""
+
+    iqr = np.percentile(a_batch, 75) - np.percentile(a_batch, 25)
+    n = a_batch[0].ndim
+    bin_width = 2 * iqr / np.power(n, 1 / 3)
+
+    # Set a minimum value for bin_width to avoid division by very small numbers.
+    min_bin_width = 1e-6
+    bin_width = max(bin_width, min_bin_width)
+
+    # Calculate number of bins based on bin width.
+    n_bins = int((np.max(a_batch) - np.min(a_batch)) / bin_width)
+
+    return n_bins
+
+
+def scotts_rule(a_batch: np.array) -> int:
+    """Scott's rule."""
+
+    std = np.std(a_batch)
+    n = a_batch[0].ndim
+
+    # Calculate bin width using Scott's rule.
+    bin_width = 3.5 * std / np.power(n, 1 / 3)
+
+    # Calculate number of bins based on bin width.
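+    # Note: unlike freedman_diaconis_rule above, no minimum bin width is
+    # enforced here before the division below.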
+ n_bins = int((np.max(a_batch) - np.min(a_batch)) / bin_width) + + return n_bins + + +def square_root_choice(a_batch: np.array) -> int: + """Square-root choice rule.""" + + n = a_batch[0].ndim + n_bins = int(np.sqrt(n)) + + return n_bins + + +def sturges_formula(a_batch: np.array) -> int: + """Sturges' formula.""" + + n = a_batch[0].ndim + n_bins = int(np.log2(n) + 1) + + return n_bins + + +def rice_rule(a_batch: np.array) -> int: + """Rice Rule.""" + + n = a_batch[0].ndim + n_bins = int(2 * np.power(n, 1 / 3)) + + return n_bins diff --git a/quantus/functions/normalise_func.py b/quantus/functions/normalise_func.py index 2d518aad4..2f5f61827 100644 --- a/quantus/functions/normalise_func.py +++ b/quantus/functions/normalise_func.py @@ -229,7 +229,7 @@ def normalise_by_average_second_moment_estimate( # Cast Sequence to tuple so numpy accepts it. normalise_axes = tuple(normalise_axes) - # Check that square root of the second momment estimatte is nonzero. + # Check that square root of the second momment estimate is nonzero. second_moment_sqrt = np.sqrt( np.sum(a ** 2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) diff --git a/quantus/helpers/constants.py b/quantus/helpers/constants.py index a785868ca..a60f02d4d 100644 --- a/quantus/helpers/constants.py +++ b/quantus/helpers/constants.py @@ -13,6 +13,7 @@ from quantus.functions.normalise_func import * from quantus.functions.perturb_func import * from quantus.functions.similarity_func import * +from quantus.functions import n_bins_func from quantus.metrics import * if sys.version_info >= (3, 8): @@ -158,6 +159,15 @@ } +AVAILABLE_N_BINS_ALGORITHMS = { + "Freedman Diaconis": n_bins_func.freedman_diaconis_rule, + "Scotts": n_bins_func.scotts_rule, + "Square Root": n_bins_func.square_root_choice, + "Sturges Formula": n_bins_func.sturges_formula, + "Rice": n_bins_func.rice_rule, +} + + def available_categories() -> List[str]: """ Retrieve the available metric categories in Quantus. diff --git a/quantus/helpers/model/model_interface.py b/quantus/helpers/model/model_interface.py index 3d3fc7605..68513d89b 100644 --- a/quantus/helpers/model/model_interface.py +++ b/quantus/helpers/model/model_interface.py @@ -20,7 +20,7 @@ class ModelInterface(ABC, Generic[M]): def __init__( self, model: M, - channel_first: bool = True, + channel_first: Optional[bool] = None, softmax: bool = False, model_predict_kwargs: Optional[Dict[str, Any]] = None, ): diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index 60f1c01ad..fd1c61c8f 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -27,7 +27,7 @@ class PyTorchModel(ModelInterface[nn.Module]): def __init__( self, model: nn.Module, - channel_first: bool = True, + channel_first: bool = False, softmax: bool = False, model_predict_kwargs: Optional[Dict[str, Any]] = None, device: Optional[str] = None, diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index 22a7a8c85..c835c2880 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -40,15 +40,10 @@ class TensorFlowModel(ModelInterface[Model]): def __init__( self, model: Model, - channel_first: bool = True, + channel_first: bool = False, softmax: bool = False, model_predict_kwargs: Optional[Dict[str, ...]] = None, ): - if model_predict_kwargs is None: - model_predict_kwargs = {} - # Disable progress bar while running inference on tf.keras.Model. 
- model_predict_kwargs["verbose"] = 0 - """ Initialisation of ModelInterface class. @@ -64,6 +59,11 @@ def __init__( model_predict_kwargs: dict, optional Keyword arguments to be passed to the model's predict method. """ + if model_predict_kwargs is None: + model_predict_kwargs = {} + # Disable progress bar while running inference on tf.keras.Model. + model_predict_kwargs["verbose"] = 0 + super().__init__( model=model, channel_first=channel_first, diff --git a/quantus/helpers/warn.py b/quantus/helpers/warn.py index 87913c6b1..169e2276b 100644 --- a/quantus/helpers/warn.py +++ b/quantus/helpers/warn.py @@ -30,17 +30,11 @@ def check_kwargs(kwargs): """ if kwargs: raise ValueError( - f"Please handle the following arguments: {kwargs}. " - "There were unexpected keyword arguments passed to the metric method. " - "Quantus has undergone heavy API-changes since the last release(s), " - "to make the kwargs-passing and error handling more robust and transparent. " - "Passing unexpected keyword arguments is now discouraged. Please adjust " - "your code to pass your kwargs in dictionaries to the arguments named " - "normalise_func_kwargs, explain_func_kwargs or model_predict_kwargs. " - "For evaluate function pass explain_func_kwargs and call_kwargs." - "And also, always make sure to check for typos. " - "If these API changes are not suitable for your project's needs, " - "please install quantus using 'pip install quantus==0.1.6' " + f"Unexpected keyword arguments encountered: {kwargs}. " + "To ensure proper usage, please refer to the 'get_params' method of the initialised metric instance " + "or consult the Quantus documentation. Avoid passing extraneous keyword arguments. " + "Ensure that your metric arguments are correctly structured, particularly 'normalise_func_kwargs', " + "'explain_func_kwargs', and 'model_predict_kwargs'. Additionally, always verify for any typos." ) diff --git a/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py b/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py index b78ff1ea6..e2967e2f7 100644 --- a/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py @@ -1,4 +1,4 @@ -"""This module contains the implementation of the Model Parameter Sensitivity metric.""" +"""This module contains the implementation of the Efficient Model Parameter Randomisation Test metric.""" # This file is part of Quantus. # Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. 
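Note: a quick sketch of the tightened check_kwargs behaviour introduced above (illustrative only; "normalize" stands in for any mistyped keyword argument):

    from quantus.helpers.warn import check_kwargs

    check_kwargs({})  # Empty kwargs pass silently.
    try:
        check_kwargs({"normalize": True})  # E.g. a typo for 'normalise'.
    except ValueError as err:
        print(err)  # Names the unexpected kwargs and points to get_params.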
@@ -15,7 +15,9 @@
     List,
     Optional,
     Union,
+    Tuple,
     Generator,
+    Iterable,
 )
 
 
@@ -24,7 +26,10 @@
 from sklearn.utils import gen_batches
 
 from quantus.functions.similarity_func import correlation_spearman
-from quantus.helpers import asserts, warn
+from quantus.functions.complexity_func import discrete_entropy, entropy
+from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate
+from quantus.functions import n_bins_func
+from quantus.helpers import asserts, warn, utils
 from quantus.helpers.enums import (
     DataType,
     EvaluationCategory,
@@ -39,16 +44,27 @@
 else:
     from typing_extensions import final
 
+AVAILABLE_N_BINS_ALGORITHMS = {
+    "Freedman Diaconis": n_bins_func.freedman_diaconis_rule,
+    "Scotts": n_bins_func.scotts_rule,
+    "Square Root": n_bins_func.square_root_choice,
+    "Sturges Formula": n_bins_func.sturges_formula,
+    "Rice": n_bins_func.rice_rule,
+}
+
 
 @final
 class EfficientModelParameterRandomisation(Metric):
     """
     Implementation of the Efficient Model Parameter Randomization Method by Hedström et. al., 2023.
 
-    The Efficient Model Parameter Randomization measures replaces the layer-by-layer pairwise comparison between e and ˆe of MPRT by instead computing the relative rise in explanation complexity using only two model states, i.e., the original- and fully randomised model version
+    The Efficient Model Parameter Randomisation measure replaces the layer-by-layer pairwise comparison
+    between e and ê of MPRT by instead computing the relative rise in explanation complexity using only
+    two model states, i.e., the original and the fully randomised model version.
 
     References:
-        1) Hedström, Anna, et al. "Sanity Checks Revisited: An Exploration to Repair the Model Parameter Randomisation Test." XAI in Action: Past, Present, and Future Applications. 2023.
+        1) Hedström, Anna, et al. "Sanity Checks Revisited: An Exploration to Repair the Model Parameter
+        Randomisation Test." XAI in Action: Past, Present, and Future Applications. 2023.
 
     Attributes:
         - _name: The name of the metric.
@@ -58,20 +74,23 @@ class EfficientModelParameterRandomisation(Metric):
         - evaluation_category: What property/ explanation quality that this metric measures.
     """
 
-    name = "Efficient Model Parameter Randomisation"
+    name = "Efficient Model Parameter Randomisation Test"
     data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
     model_applicability = {ModelType.TORCH, ModelType.TF}
-    score_direction = ScoreDirection.LOWER
+    score_direction = ScoreDirection.HIGHER
     evaluation_category = EvaluationCategory.RANDOMISATION
 
     def __init__(
         self,
+        complexity_func: Optional[Callable] = None,
+        complexity_func_kwargs: Optional[dict] = None,
         similarity_func: Optional[Callable] = None,
-        layer_order: str = "independent",
+        layer_order: str = "bottom_up",
         seed: int = 42,
-        return_sample_correlation: bool = False,
-        abs: bool = True,
-        normalise: bool = True,
+        compute_extra_scores: bool = False,
+        skip_layers: bool = True,
+        abs: bool = False,
+        normalise: bool = False,
         normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         normalise_func_kwargs: Optional[Dict[str, Any]] = None,
         return_aggregate: bool = False,
@@ -92,9 +111,15 @@ def __init__(
             default="independent".
        seed: integer
            Seed used for the random generator, default=42.
-        return_sample_correlation: boolean
-            Indicates whether return one float per sample, representing the average
-            correlation coefficient across the layers for that sample.
+        compute_extra_scores: boolean
+            Indicates whether to compute additional scores, e.g., deltas, fractions and
+            model-versus-explanation complexity correlations, stored in 'scores_extra',
+            default=False.
+        skip_layers: boolean
+            Indicates if explanation complexity should be computed only once; between the
+            original and fully randomised model, instead of on a layer-by-layer basis.
         abs: boolean
             Indicates whether absolute operation is applied on the attribution, default=True.
         normalise: boolean
             Indicates whether normalise operation is applied on the attribution, default=True.
         normalise_func: callable
             Attribution normalisation function applied in case normalise=True.
             If normalise_func=None, the default value is used, default=normalise_by_max.
@@ -132,17 +157,35 @@
         )
 
         # Save metric-specific attributes.
+        if complexity_func is None:
+            complexity_func = discrete_entropy
+
+        if complexity_func_kwargs is None:
+            complexity_func_kwargs = {}
+
+        if normalise_func is None:
+            normalise_func = normalise_by_average_second_moment_estimate
+
+        if normalise_func_kwargs is None:
+            normalise_func_kwargs = {}
+
         if similarity_func is None:
             similarity_func = correlation_spearman
+
+        self.complexity_func = complexity_func
+        self.complexity_func_kwargs = complexity_func_kwargs
+        self.normalise_func = normalise_func
+        self.abs = abs
+        self.normalise_func_kwargs = normalise_func_kwargs
         self.similarity_func = similarity_func
         self.layer_order = layer_order
         self.seed = seed
-        self.return_sample_correlation = return_sample_correlation
+        self.compute_extra_scores = compute_extra_scores
+        self.skip_layers = skip_layers
 
         # Results are returned/saved as a dictionary not like in the super-class as a list.
         self.evaluation_scores = {}
 
-        # Asserts and warnings.
         asserts.assert_layer_order(layer_order=self.layer_order)
         if not self.disable_warnings:
             warn.warn_parameterisation(
@@ -169,7 +212,7 @@
         explain_func: Optional[Callable] = None,
         explain_func_kwargs: Optional[Dict] = None,
         model_predict_kwargs: Optional[Dict] = None,
-        softmax: Optional[bool] = False,
+        softmax: Optional[bool] = True,
         device: Optional[str] = None,
         batch_size: int = 64,
         **kwargs,
@@ -255,6 +298,7 @@
         warn.deprecation_warnings(kwargs)
         warn.check_kwargs(kwargs)
         self.batch_size = batch_size
+
         data = self.general_preprocess(
             model=model,
             x_batch=x_batch,
@@ -270,10 +314,12 @@
             device=device,
         )
         model: ModelInterface = data["model"]  # type: ignore
+
         # Here _batch refers to full dataset.
         x_full_dataset = data["x_batch"]
         y_full_dataset = data["y_batch"]
         a_full_dataset = data["a_batch"]
+
         # Results are returned/saved as a dictionary not as a list as in the super-class.
         self.evaluation_scores = {}
 
@@ -286,71 +332,207 @@
         # Set property to False, so we display only 1 pbar.
         self._display_progressbar = False
 
+        # Get the number of bins for discrete entropy calculation.
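+        # (If no explicit 'n_bins' is passed via complexity_func_kwargs, the
+        #  block below generates explanations once up front and find_n_bins()
+        #  derives a bin count from their statistics, clamped to
+        #  [min_n_bins, max_n_bins]; see find_n_bins further below.)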
+ if "n_bins" not in self.complexity_func_kwargs: + if a_batch is None: + a_batch = self.explain_batch( + model=model, + x_batch=x_full_dataset, + y_batch=y_full_dataset, + ) + self.find_n_bins( + a_batch=a_batch, + n_bins_default=self.complexity_func_kwargs.get("n_bins_default", 100), + min_n_bins=self.complexity_func_kwargs.get("min_n_bins", 10), + max_n_bins=self.complexity_func_kwargs.get("max_n_bins", 200), + debug=self.complexity_func_kwargs.get("debug", False), + ) + def generate_y_batches(): for batch in gen_batches(len(a_full_dataset), batch_size): yield a_full_dataset[batch.start : batch.stop] + self.explanation_scores_by_layer = {} + self.model_scores_by_layer = {} + with pbar as pbar: - for layer_name, random_layer_model in model.get_random_layer_generator( - order=self.layer_order, seed=self.seed + for l_ix, (layer_name, random_layer_model) in enumerate( + model.get_random_layer_generator(order=self.layer_order, seed=self.seed) ): pbar.desc = layer_name - similarity_scores = [] - # Generate explanations on modified model in batches - a_perturbed_generator = self.generate_explanations( - random_layer_model, x_full_dataset, y_full_dataset, batch_size - ) + # Skip layers if computing delta. + if self.skip_layers and (l_ix + 1) < n_layers: + continue + + if l_ix == 0: + # Generate explanations on modified model in batches. + a_original_generator = self.generate_explanations( + model.get_model(), x_full_dataset, y_full_dataset, batch_size + ) - for a_batch, a_batch_perturbed in zip( - generate_y_batches(), a_perturbed_generator - ): - for a_instance, a_instance_perturbed in zip( - a_batch, a_batch_perturbed + # Compute the complexity of explanations of the original model. + self.explanation_scores_by_layer["orig"] = [] + for a_batch, a_batch_original in zip( + generate_y_batches(), a_original_generator ): - result = self.similarity_func( - a_instance_perturbed.flatten(), a_instance.flatten() - ) - similarity_scores.append(result) - pbar.update(1) - # Save similarity scores in a result dictionary. - self.evaluation_scores[layer_name] = similarity_scores + for a_instance, a_instance_original in zip( + a_batch, a_batch_original + ): + score = self.evaluate_instance( + model=model, + x=x_batch[0], + y=None, + s=None, + a=a_instance_original, + ) + self.explanation_scores_by_layer["orig"].append(score) + pbar.update(1) + + # Compute the similarity of outputs of the original model. + self.model_scores_by_layer["orig"] = [] + y_preds = model.predict(x_full_dataset) + for y_ix, y_pred in enumerate(y_preds): + score = entropy(a=y_pred, x=y_pred) + self.model_scores_by_layer["orig"].append(score) + + # Generate explanations on modified model in batches. + a_perturbed_generator = self.generate_explanations( + random_layer_model, x_full_dataset, y_full_dataset, batch_size + ) + + # Compute the complexity of explanations of the perturbed model. + self.explanation_scores_by_layer[layer_name] = [] + for a_batch, a_batch_perturbed in zip( + generate_y_batches(), a_perturbed_generator + ): + for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed): + score = self.evaluate_instance( + model=random_layer_model, + x=None, + y=None, + s=None, + a=a_instance_perturbed, + ) + self.explanation_scores_by_layer[layer_name].append(score) + pbar.update(1) + + # Wrap the model. 
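+                # (Re-wrapping exposes the randomised torch/tf model through the
+                #  same predict()/softmax/device interface as the wrapper around
+                #  the original model.)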
+                random_layer_model_wrapped = utils.get_wrapped_model(
+                    model=random_layer_model,
+                    channel_first=channel_first,
+                    softmax=softmax,
+                    device=device,
+                    model_predict_kwargs=model_predict_kwargs,
+                )
+                # Reshape input according to model (PyTorch or Keras/Torch).
+                x_full_dataset = model.shape_input(
+                    x=x_full_dataset,
+                    shape=x_full_dataset.shape,
+                    channel_first=channel_first,
+                    batched=True,
+                )
+
+                # Predict and save complexity scores of the perturbed model outputs.
+                self.model_scores_by_layer[layer_name] = []
+                y_preds = random_layer_model_wrapped.predict(x_full_dataset)
+                for y_ix, y_pred in enumerate(y_preds):
+                    score = entropy(a=y_pred, x=y_pred)
+                    self.model_scores_by_layer[layer_name].append(score)
+
+        # Save evaluation scores as the relative rise in complexity.
+        explanation_scores = list(self.explanation_scores_by_layer.values())
+        self.evaluation_scores = [
+            (b - a) / a for a, b in zip(explanation_scores[0], explanation_scores[-1])
+        ]
+
+        # Compute extra scores and save the results in metric attributes.
+        if self.compute_extra_scores:
+            self.scores_extra = {}
+
+            # Compute absolute deltas for explanation scores.
+            self.scores_extra["delta_explanation_scores"] = [
+                b - a for a, b in zip(explanation_scores[0], explanation_scores[-1])
+            ]
+
+            # Compute simple fraction for explanation scores.
+            self.scores_extra["scores_fraction_explanation"] = [
+                b / a if a != 0 else np.nan
+                for a, b in zip(explanation_scores[0], explanation_scores[-1])
+            ]
+
+            # Compute absolute deltas for model scores.
+            model_scores = list(self.model_scores_by_layer.values())
+            self.scores_extra["scores_delta_model"] = [
+                b - a for a, b in zip(model_scores[0], model_scores[-1])
+            ]
+
+            # Compute simple fraction for model scores.
+            self.scores_extra["scores_fraction_model"] = [
+                b / a if a != 0 else np.nan
+                for a, b in zip(model_scores[0], model_scores[-1])
+            ]
+
+            # Compute delta skill score per sample (model versus explanation).
+            self.scores_extra["delta_explanation_vs_models"] = [
+                b / a if a != 0 else np.nan
+                for a, b in zip(
+                    self.scores_extra["scores_fraction_model"],
+                    self.scores_extra["scores_fraction_explanation"],
+                )
+            ]
+
+            # Compute the average complexity scores, per sample.
+            self.scores_extra[
+                "scores_average_complexity"
+            ] = self.recompute_average_complexity_per_sample()
 
-        if self.return_sample_correlation:
-            self.evaluation_scores = self.compute_correlation_per_sample()
+            # Compute the correlation coefficient between the model and explanation complexity, per sample.
+            self.scores_extra[
+                "scores_correlation_model_vs_explanation_complexity"
+            ] = self.recompute_model_explanation_correlation_per_sample()
 
         if self.return_aggregate:
-            assert self.return_sample_correlation, (
-                "You must set 'return_average_correlation_per_sample'"
-                " to True in order to compute te aggregat"
-            )
             self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)]
 
+        # Return all_evaluation_scores according to Quantus.
         self.all_evaluation_scores.append(self.evaluation_scores)
 
         return self.evaluation_scores
 
-    def compute_correlation_per_sample(
+    def evaluate_instance(
         self,
-    ) -> Union[List[List[Any]], Dict[int, List[Any]]]:
-        assert isinstance(self.evaluation_scores, dict), (
-            "To compute the average correlation coefficient per sample for "
-            "Model Parameter Randomisation Test, 'last_result' "
-            "must be of type dict."
- ) - layer_length = len( - self.evaluation_scores[list(self.evaluation_scores.keys())[0]] - ) - results: Dict[int, list] = {sample: [] for sample in range(layer_length)} - - for sample in results: - for layer in self.evaluation_scores: - results[sample].append(float(self.evaluation_scores[layer][sample])) - results[sample] = np.mean(results[sample]) + model: ModelInterface, + x: Optional[np.ndarray], + y: Optional[np.ndarray], + a: Optional[np.ndarray], + s: Optional[np.ndarray], + ) -> float: + """ + Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - corr_coeffs = list(results.values()) + Parameters + ---------- + i: integer + The evaluation instance. + model: ModelInterface + A ModelInteface that is subject to explanation. + x: np.ndarray + The input to be evaluated on an instance-basis. + y: np.ndarray + The output to be evaluated on an instance-basis. + a: np.ndarray + The explanation to be evaluated on an instance-basis. + s: np.ndarray + The segmentation to be evaluated on an instance-basis. - return corr_coeffs + Returns + ------- + float + The evaluation results. + """ + # Compute complexity measure. + return self.complexity_func(a=a, x=x, **self.complexity_func_kwargs) def custom_preprocess( self, @@ -382,8 +564,7 @@ def custom_preprocess( # Additional explain_func assert, as the one in general_preprocess() # won't be executed when a_batch != None. asserts.assert_explain_func(explain_func=self.explain_func) - if a_batch is not None: - # Just to silence mypy warnings + if a_batch is not None: # Just to silence mypy warnings return None a_batch_chunks = [] @@ -409,5 +590,145 @@ def generate_explanations( def evaluate_batch(self, *args, **kwargs): raise RuntimeError( - "`evaluate_batch` must never be called for `ModelParameterRandomisation`." + "`evaluate_batch` must never be called for `Model Parameter Randomisation`." + ) + + def recompute_model_explanation_correlation_per_sample( + self, + ) -> Union[List[List[Any]], Dict[int, List[Any]]]: + + assert isinstance(self.explanation_scores_by_layer, dict), ( + "To compute the correlation between model and explanation per sample for " + "enhanced Model Parameter Randomisation Test, 'explanation_scores' " + "must be of type dict." + ) + layer_length = len( + self.explanation_scores_by_layer[ + list(self.explanation_scores_by_layer.keys())[0] + ] ) + explanation_scores: Dict[int, list] = { + sample: [] for sample in range(layer_length) + } + model_scores: Dict[int, list] = {sample: [] for sample in range(layer_length)} + + for sample in explanation_scores.keys(): + for layer in self.explanation_scores_by_layer: + explanation_scores[sample].append( + float(self.explanation_scores_by_layer[layer][sample]) + ) + model_scores[sample].append( + float(self.model_scores_by_layer[layer][sample]) + ) + + corr_coeffs = [] + for sample in explanation_scores.keys(): + corr_coeffs.append( + self.similarity_func(model_scores[sample], explanation_scores[sample]) + ) + + return corr_coeffs + + def recompute_average_complexity_per_sample( + self, + ) -> Union[List[List[Any]], Dict[int, List[Any]]]: + + assert isinstance(self.explanation_scores_by_layer, dict), ( + "To compute the average correlation coefficient per sample for " + "enhanced Model Parameter Randomisation Test, 'explanation_scores' " + "must be of type dict." 
+        )
+        layer_length = len(
+            self.explanation_scores_by_layer[
+                list(self.explanation_scores_by_layer.keys())[0]
+            ]
+        )
+        results: Dict[int, list] = {sample: [] for sample in range(layer_length)}
+
+        for sample in results:
+            for layer in self.explanation_scores_by_layer:
+                if layer == "orig":
+                    continue
+                results[sample].append(
+                    float(self.explanation_scores_by_layer[layer][sample])
+                )
+            results[sample] = np.mean(results[sample])
+
+        corr_coeffs = list(results.values())
+
+        return corr_coeffs
+
+    def recompute_last_correlation_per_sample(
+        self,
+    ) -> Union[List[List[Any]], Dict[int, List[Any]]]:
+
+        assert isinstance(self.explanation_scores_by_layer, dict), (
+            "To compute the last correlation coefficient per sample for "
+            "Model Parameter Randomisation Test, 'explanation_scores' "
+            "must be of type dict."
+        )
+        corr_coeffs = list(self.explanation_scores_by_layer.values())[-1]
+        corr_coeffs = [float(c) for c in corr_coeffs]
+        return corr_coeffs
+
+    def find_n_bins(
+        self,
+        a_batch: np.array,
+        n_bins_default: int = 100,
+        min_n_bins: int = 10,
+        max_n_bins: int = 200,
+        debug: bool = True,
+    ) -> None:
+        """
+        Find the number of bins for discrete entropy calculation.
+
+        Parameters
+        ----------
+        a_batch: np.array
+            Explanation array to calculate entropy on.
+        n_bins_default: int
+            Default number of bins to use if no rule is found, default=100.
+        min_n_bins: int
+            Minimum number of bins to use, default=10.
+        max_n_bins: int
+            Maximum number of bins to use, default=200.
+        debug: boolean
+            Indicates whether to print debug information, default=True.
+
+        Returns
+        -------
+        None
+        """
+        if self.normalise:
+            a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs)
+
+        if self.abs:
+            a_batch = np.abs(a_batch)
+
+        if debug:
+            print(f"\tMax and min value of a_batch=({a_batch.min()}, {a_batch.max()})")
+
+        try:
+            rule_name = self.complexity_func_kwargs.get("rule", None)
+            rule_function = AVAILABLE_N_BINS_ALGORITHMS.get(rule_name)
+        except:
+            print(
                f"Attempted to use a rule '{rule_name}' that is not available in existing rules: "
+                f"{AVAILABLE_N_BINS_ALGORITHMS.keys()}."
+            )
+
+        if not rule_function:
+            self.complexity_func_kwargs["n_bins"] = n_bins_default
+            if debug:
+                print(f"\tNo rule found, 'n_bins' set to {n_bins_default}.")
+            return None
+
+        n_bins = rule_function(a_batch=a_batch)
+        n_bins = max(min(n_bins, max_n_bins), min_n_bins)
+        self.complexity_func_kwargs["n_bins"] = n_bins
+
+        if debug:
+            print(
+                f"\tRule '{rule_name}' -> n_bins={n_bins} but with min={min_n_bins} "
+                f"and max={max_n_bins}, 'n_bins' set to {self.complexity_func_kwargs['n_bins']}."
+            )
diff --git a/quantus/metrics/randomisation/mprt.py b/quantus/metrics/randomisation/emprt.py
similarity index 53%
rename from quantus/metrics/randomisation/mprt.py
rename to quantus/metrics/randomisation/emprt.py
index 79310c1c9..6edbf0def 100644
--- a/quantus/metrics/randomisation/mprt.py
+++ b/quantus/metrics/randomisation/emprt.py
@@ -1,4 +1,4 @@
-"""This module contains the implementation of the Model Parameter Randomisation Test metric."""
+"""This module contains the implementation of the enhanced Model Parameter Randomisation Test metric."""
 
 # This file is part of Quantus.
 # Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
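For reference, find_n_bins above effectively performs the following (a minimal sketch; the toy batch and the clamp bounds 10/200 are illustrative):

    import numpy as np
    from quantus.functions.complexity_func import discrete_entropy
    from quantus.functions import n_bins_func

    a_batch = np.abs(np.random.randn(8, 1, 28, 28))    # Toy attribution batch.
    n_bins = n_bins_func.scotts_rule(a_batch=a_batch)  # Any of the named rules.
    n_bins = max(min(n_bins, 200), 10)                 # Clamp to [min_n_bins, max_n_bins].
    score = discrete_entropy(a=a_batch[0], x=a_batch[0], n_bins=n_bins)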
@@ -17,14 +17,18 @@
     Collection,
     Iterable,
 )
+import os
 
 import numpy as np
 from tqdm.auto import tqdm
+import torch
 
 from quantus.helpers import asserts
 from quantus.helpers import warn
+from quantus.helpers import utils
 from quantus.helpers.model.model_interface import ModelInterface
 from quantus.functions.normalise_func import normalise_by_max
 from quantus.functions.similarity_func import correlation_spearman
+from quantus.functions.complexity_func import discrete_entropy, entropy
+from quantus.functions import n_bins_func
 from quantus.metrics.base import Metric
 from quantus.helpers.enums import (
     ModelType,
@@ -34,20 +38,14 @@
 )
 
+AVAILABLE_N_BINS_ALGORITHMS = {
+    "Freedman Diaconis": n_bins_func.freedman_diaconis_rule,
+    "Scotts": n_bins_func.scotts_rule,
+    "Square Root": n_bins_func.square_root_choice,
+    "Sturges Formula": n_bins_func.sturges_formula,
+    "Rice": n_bins_func.rice_rule,
+}
+
 
-class MPRT(Metric):
+class eMPRT(Metric):
     """
-    Implementation of the Model Parameter Randomisation Test by Adebayo et. al., 2018.
-
-    The Model Parameter Randomization measures the distance between the original attribution and a newly computed
-    attribution throughout the process of cascadingly/independently randomizing the model parameters of one layer
-    at a time.
+    Implementation of the enhanced Model Parameter Randomisation Test (eMPRT) by Hedström et al., 2023.
 
-    Assumptions:
-        - In the original paper multiple distance measures are taken: Spearman rank correlation (with and without abs),
-        HOG and SSIM. We have set Spearman as the default value.
+    Instead of comparing original and randomised attributions layer-by-layer, the enhanced test
+    scores the rise in explanation complexity as the model parameters are randomised.
 
     References:
-        1) Julius Adebayo et al.: "Sanity Checks for Saliency Maps." NeurIPS (2018): 9525-9536.
+        1) Hedström, Anna, et al. "Sanity Checks Revisited: An Exploration to Repair the Model
+        Parameter Randomisation Test." XAI in Action: Past, Present, and Future Applications. 2023.
 
     Attributes:
         - _name: The name of the metric.
@@ -57,22 +55,35 @@ class MPRT(Metric):
         - evaluation_category: What property/ explanation quality that this metric measures.
     """
 
-    name = "Model Parameter Randomisation Test"
+    name = "Enhanced Model Parameter Randomisation Test"
     data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
     model_applicability = {ModelType.TORCH, ModelType.TF}
-    score_direction = ScoreDirection.LOWER
+    score_direction = ScoreDirection.HIGHER
     evaluation_category = EvaluationCategory.RANDOMISATION
 
     def __init__(
         self,
-        similarity_func: Callable = None,
-        layer_order: str = "independent",
+        complexity_func: Optional[Callable] = None,
+        complexity_func_kwargs: Optional[dict] = None,
+        layer_order: str = "bottom_up",
+        nr_samples: Optional[int] = None,
         seed: int = 42,
-        return_sample_correlation: bool = False,
-        return_last_correlation: bool = False,
+        compute_delta: bool = True,
+        compute_rate_of_change: bool = True,
+        compute_delta_explanation_vs_model: bool = True,
+        compute_correlation: bool = True,
+        compute_last_complexity: bool = True,
+        return_delta_explanation_vs_model: bool = False,
+        return_fraction: bool = False,
+        return_rate_of_change: bool = True,
+        return_average_sample_score: bool = False,
+        return_correlation: bool = False,
+        return_last_complexity: bool = False,
+        return_delta_explanation: bool = False,
         skip_layers: bool = False,
-        abs: bool = True,
-        normalise: bool = True,
+        similarity_func: Optional[Callable] = None,
+        abs: bool = False,
+        normalise: bool = False,
         normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         normalise_func_kwargs: Optional[Dict[str, Any]] = None,
         return_aggregate: bool = False,
@@ -93,7 +104,7 @@ def __init__(
             default="independent".
         seed: integer
             Seed used for the random generator, default=42.
-        return_sample_correlation: boolean
+        return_average_sample_score: boolean
             Indicates whether return one float per sample, representing the average
             correlation coefficient across the layers for that sample.
         abs: boolean
@@ -134,25 +145,64 @@
             **kwargs,
         )
 
+        # Set seed for reproducibility.
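+        # (Seeding torch, CUDA and numpy, plus PYTHONHASHSEED, keeps the layer
+        #  randomisation and any stochastic explain_func reproducible per run.)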
+        if seed is not None:
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)
+            np.random.seed(seed)
+            os.environ["PYTHONHASHSEED"] = str(seed)
+
+        # torch.backends.cudnn.benchmark = False
+        # torch.backends.cudnn.deterministic = True
+        # torch.backends.cudnn.enabled = False
+
         # Save metric-specific attributes.
+        if complexity_func is None:
+            complexity_func = discrete_entropy
+
+        if complexity_func_kwargs is None:
+            complexity_func_kwargs = {}
+
         if similarity_func is None:
             similarity_func = correlation_spearman
+
+        self.complexity_func = complexity_func
+        self.complexity_func_kwargs = complexity_func_kwargs
         self.similarity_func = similarity_func
         self.layer_order = layer_order
-        self.seed = seed
-        self.return_sample_correlation = return_sample_correlation
-        self.return_last_correlation = return_last_correlation
+        self.nr_samples = nr_samples
+        self.compute_delta = compute_delta
+        self.compute_rate_of_change = compute_rate_of_change
+        self.compute_delta_explanation_vs_model = compute_delta_explanation_vs_model
+        self.compute_correlation = compute_correlation
+        self.compute_last_complexity = compute_last_complexity
+        self.return_average_sample_score = return_average_sample_score
+        self.return_fraction = return_fraction
+        self.return_rate_of_change = return_rate_of_change
+        self.return_delta_explanation_vs_model = return_delta_explanation_vs_model
+        self.return_correlation = return_correlation
+        self.return_last_complexity = return_last_complexity
+        self.return_delta_explanation = return_delta_explanation
         self.skip_layers = skip_layers
 
-        if self.return_sample_correlation and self.return_last_correlation:
-            raise ValueError(
-                f"Both 'return_sample_correlation' and 'return_last_correlation' cannot be True. Pick one."
+        # Asserts and warnings.
+        assert (
+            sum(
+                [
+                    self.return_fraction,
+                    self.return_average_sample_score,
+                    self.return_correlation,
+                    self.return_last_complexity,
+                    self.return_delta_explanation,
+                    self.return_delta_explanation_vs_model,
+                    self.return_rate_of_change,
+                ]
             )
+            == 1
+        ), "Set exactly one of the 'return_*' arguments to True."
 
-        # Results are returned/saved as a dictionary not like in the super-class as a list.
-        self.evaluation_scores = {}
-
-        # Asserts and warnings.
         asserts.assert_layer_order(layer_order=self.layer_order)
         if not self.disable_warnings:
             warn.warn_parameterisation(
@@ -183,6 +233,7 @@
         device: Optional[str] = None,
         batch_size: int = 64,
         custom_batch: Optional[Any] = None,
+        attributions_path: Optional[str] = None,
         **kwargs,
     ) -> Union[List[float], float, Dict[str, List[float]], Collection[Any]]:
         """
@@ -280,39 +331,51 @@
             softmax=softmax,
             device=device,
         )
+
+        # Get model and data.
         model = data["model"]
         x_batch = data["x_batch"]
         y_batch = data["y_batch"]
         a_batch = data["a_batch"]
 
-        # Results are returned/saved as a dictionary not as a list as in the super-class.
-        self.correlation_scores = np.zeros((len(x_batch)))
-        self.similarity_scores = {}
-        self.evaluation_scores = {}
-
         # Get number of iterations from number of layers.
n_layers = len(list(model.get_random_layer_generator(order=self.layer_order))) - model_iterator = tqdm( - model.get_random_layer_generator(order=self.layer_order, seed=self.seed), + model.get_random_layer_generator(order=self.layer_order), total=n_layers, disable=not self.display_progressbar, ) - for l_ix, (layer_name, random_layer_model) in enumerate(model_iterator): + # Get the number of bins for discrete entropy calculation. + if "n_bins" not in self.complexity_func_kwargs: + self.find_n_bins( + a_batch=a_batch, + n_bins_default=self.complexity_func_kwargs.get("n_bins_default", 100), + min_n_bins=self.complexity_func_kwargs.get("min_n_bins", 10), + max_n_bins=self.complexity_func_kwargs.get("max_n_bins", 200), + debug=self.complexity_func_kwargs.get("debug", False), + ) - similarity_scores = [None for _ in x_batch] + # Compute the explanation_scores given uniformly sampled explanation. + if self.nr_samples is None: + self.nr_samples = len(a_batch) + + # Initialise arrays. + self.delta_explanation_scores = np.zeros((self.nr_samples)) + self.delta_model_scores = np.zeros((self.nr_samples)) + self.fraction_explanation_scores = np.zeros((self.nr_samples)) + self.fraction_model_scores = np.zeros((self.nr_samples)) + self.delta_explanation_vs_models = np.zeros((self.nr_samples)) + self.correlation_scores = np.zeros((self.nr_samples)) + self.rate_of_change_scores = np.zeros((self.nr_samples)) + self.explanation_scores = {} + self.model_scores = {} - # Skip layers if computing delta. - if self.skip_layers and (l_ix + 1) < len(model_iterator): - continue + for l_ix, (layer_name, random_layer_model) in enumerate(model_iterator): - # Save correlation scores of no perturbation. - if ( - l_ix == 0 - ): # (l_ix == 0 and self.layer_order == "bottom_up") or (l_ix+1 == len(model_iterator) and self.layer_order == "top_down"): + if l_ix == 0: - # Generate an explanation with original model. + # Generate an explanation with perturbed model. a_batch_original = self.explain_func( model=model.get_model(), inputs=x_batch, @@ -320,20 +383,33 @@ def __call__( **self.explain_func_kwargs, ) - batch_iterator = enumerate(zip(a_batch, a_batch_original)) - for instance_id, (a_instance, a_ori) in batch_iterator: + self.explanation_scores["orig"] = [] + for a_ix, a_ori in enumerate(a_batch_original): score = self.evaluate_instance( model=model, - x=None, + x=x_batch[0], y=None, s=None, - a=a_instance, - a_perturbed=a_ori, + a=a_ori, ) - similarity_scores[instance_id] = score + self.explanation_scores["orig"].append(score) + + # Compute entropy of the output layer. + self.model_scores["orig"] = [] + for y_ix, y_pred in enumerate(model.predict(x_batch)): + score = entropy(a=y_pred, x=y_pred) + self.model_scores["orig"].append(score) - # Save similarity scores in a result dictionary. - self.similarity_scores["orig"] = similarity_scores + # Skip layers if computing delta. + if ( + self.skip_layers + and self.compute_delta + and (l_ix + 1) < len(model_iterator) + ): + continue + + # Score explanation complexity. + explanation_scores = [] # Generate an explanation with perturbed model. a_batch_perturbed = self.explain_func( @@ -343,20 +419,71 @@ def __call__( **self.explain_func_kwargs, ) + # Get id for storing data. 
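+            # (When attributions_path is set, inputs plus original and perturbed
+            #  attributions are saved per randomised layer as .npy files, with
+            #  ids resuming from the largest index already present on disk.)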
+ if attributions_path is not None: + savepath = os.path.join(attributions_path, f"{l_ix}-{layer_name}") + os.makedirs(savepath, exist_ok=True) + last_id = 0 + for fname in os.listdir(savepath): + if "original_attribution_" in fname: + id = ( + int(fname.split("original_attribution_")[1].split(".")[0]) + > last_id + ) + if id > last_id: + last_id = id + batch_iterator = enumerate(zip(a_batch, a_batch_perturbed)) - for instance_id, (a_instance, a_instance_perturbed) in batch_iterator: + for instance_id, (a_ix, a_perturbed) in batch_iterator: score = self.evaluate_instance( model=random_layer_model, - x=None, + x=x_batch[0], y=None, s=None, - a=a_instance, - a_perturbed=a_instance_perturbed, + a=a_perturbed, ) - similarity_scores[instance_id] = score + explanation_scores.append(score) + + # Save data. + if attributions_path is not None: + np.save( + os.path.join(savepath, f"input_{last_id+instance_id}.npy"), + x_batch[instance_id], + ) + np.save( + os.path.join( + savepath, f"original_attribution_{last_id+instance_id}.npy" + ), + a_ix, + ) + np.save( + os.path.join( + savepath, f"perturbed_attribution_{last_id+instance_id}.npy" + ), + a_perturbed, + ) + + # Score the model complexity. + model_scores = [] - # Save similarity scores in a result dictionary. - self.similarity_scores[layer_name] = similarity_scores + # Wrap the model. + random_layer_model_wrapped = utils.get_wrapped_model( + model=random_layer_model, + channel_first=channel_first, + softmax=softmax, + device=device, + model_predict_kwargs=model_predict_kwargs, + ) + + # Predict and save scores. + y_preds = random_layer_model_wrapped.predict(x_batch) + for y_ix, y_pred in enumerate(y_preds): + score = entropy(a=y_pred, x=y_pred) + model_scores.append(score) + + # Save explanation_scores scores in a result dictionary. + self.explanation_scores[layer_name] = explanation_scores + self.model_scores[layer_name] = model_scores # Call post-processing. self.custom_postprocess( @@ -367,21 +494,84 @@ def __call__( s_batch=s_batch, ) - if self.return_sample_correlation: - self.correlation_scores = self.recompute_correlation_per_sample() - self.evaluation_scores = self.correlation_scores + # If compute correlation score (model and explanations) + if self.compute_correlation: + self.correlation_scores = ( + self.recompute_model_explanation_correlation_per_sample() + ) - elif self.return_last_correlation: - self.correlation_scores = self.recompute_last_correlation_per_sample() + # If compute the last complexity score. + if self.compute_last_complexity: + self.last_complexity_scores = self.recompute_last_correlation_per_sample() + + # If compute delta score per sample (model and explanations). + if self.compute_delta: + + # Compute deltas for explanation scores. + scores = list(self.explanation_scores.values()) + self.delta_explanation_scores = [ + b - a for a, b in zip(scores[0], scores[-1]) + ] + + # Compute deltas for model scores. + scores = list(self.model_scores.values()) + self.delta_model_scores = [b - a for a, b in zip(scores[0], scores[-1])] + + # Compute fraction for explanation scores. + scores = list(self.explanation_scores.values()) + self.fraction_explanation_scores = [ + b / a if a != 0 else np.nan for a, b in zip(scores[0], scores[-1]) + ] # eMPRT original! + + # Compute fraction for explanation scores. + scores = list(self.model_scores.values()) + self.fraction_model_scores = [ + b / a if a != 0 else np.nan for a, b in zip(scores[0], scores[-1]) + ] + + # If compute delta skill score per sample (model and explanations). 
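+        # (The skill ratio below divides the explanation's complexity fraction
+        #  by the model's: values above 1 mean the explanation's complexity
+        #  rises faster under randomisation than the model output's does.)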
+ if self.compute_delta_explanation_vs_model: + self.delta_explanation_vs_models = [ + b / a if a != 0 else np.nan + for a, b in zip( + self.fraction_model_scores, self.fraction_explanation_scores + ) + ] + + # If compute delta skill score per sample (model and explanations). + if self.compute_rate_of_change: + scores = list(self.explanation_scores.values()) + self.rate_of_change_scores = [ + (b - a) / a for a, b in zip(scores[0], scores[-1]) + ] + + # If return one score per sample. + if self.return_average_sample_score: + self.evaluation_scores = self.recompute_average_complexity_per_sample() + + # If return delta score per sample. + if self.return_fraction: + self.evaluation_scores = self.fraction_explanation_scores + + # If return delta score per sample. + if self.return_delta_explanation_vs_model: + self.evaluation_scores = self.delta_explanation_vs_models + + # If return delta score per sample. + if self.return_correlation: self.evaluation_scores = self.correlation_scores + if self.return_last_complexity: + self.evaluation_scores = self.last_complexity_scores + + if self.return_rate_of_change: + self.evaluation_scores = self.rate_of_change_scores + + # If return one aggregate score for all samples. if self.return_aggregate: - assert self.return_sample_correlation, ( - "You must set 'return_average_correlation_per_sample'" - " to True in order to compute te aggregat" - ) self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)] + # Return all_evaluation_scores according to Quantus. self.all_evaluation_scores.append(self.evaluation_scores) return self.evaluation_scores @@ -393,7 +583,6 @@ def evaluate_instance( y: Optional[np.ndarray], a: Optional[np.ndarray], s: Optional[np.ndarray], - a_perturbed: Optional[np.ndarray] = None, ) -> float: """ Evaluate instance gets model and data for a single instance as input and returns the evaluation result. @@ -412,8 +601,6 @@ def evaluate_instance( The explanation to be evaluated on an instance-basis. s: np.ndarray The segmentation to be evaluated on an instance-basis. - a_perturbed: np.ndarray - The perturbed attributions. Returns ------- @@ -421,13 +608,13 @@ def evaluate_instance( The evaluation results. """ if self.normalise: - a_perturbed = self.normalise_func(a_perturbed, **self.normalise_func_kwargs) + a = self.normalise_func(a, **self.normalise_func_kwargs) if self.abs: - a_perturbed = np.abs(a_perturbed) + a = np.abs(a) # Compute distance measure. - return self.similarity_func(a_perturbed.flatten(), a.flatten()) + return self.complexity_func(a=a, x=x, **self.complexity_func_kwargs) def custom_preprocess( self, @@ -462,27 +649,60 @@ def custom_preprocess( """ # Additional explain_func assert, as the one in general_preprocess() # won't be executed when a_batch != None. + asserts.assert_explain_func(explain_func=self.explain_func) - def recompute_correlation_per_sample( + def recompute_model_explanation_correlation_per_sample( self, ) -> Union[List[List[Any]], Dict[int, List[Any]]]: - assert isinstance(self.similarity_scores, dict), ( + assert isinstance(self.explanation_scores, dict), ( + "To compute the correlation between model and explanation per sample for " + "enhanced Model Parameter Randomisation Test, 'explanation_scores' " + "must be of type dict." 
+        )
+        layer_length = len(
+            self.explanation_scores[list(self.explanation_scores.keys())[0]]
+        )
+        explanation_scores: Dict[int, list] = {
+            sample: [] for sample in range(layer_length)
+        }
+        model_scores: Dict[int, list] = {sample: [] for sample in range(layer_length)}
+
+        for sample in explanation_scores.keys():
+            for layer in self.explanation_scores:
+                explanation_scores[sample].append(
+                    float(self.explanation_scores[layer][sample])
+                )
+                model_scores[sample].append(float(self.model_scores[layer][sample]))
+
+        corr_coeffs = []
+        for sample in explanation_scores.keys():
+            corr_coeffs.append(
+                self.similarity_func(model_scores[sample], explanation_scores[sample])
+            )
+
+        return corr_coeffs
+
+    def recompute_average_complexity_per_sample(
+        self,
+    ) -> Union[List[List[Any]], Dict[int, List[Any]]]:
+
+        assert isinstance(self.explanation_scores, dict), (
             "To compute the average correlation coefficient per sample for "
-            "enhanced Model Parameter Randomisation Test, 'similarity_scores' "
+            "enhanced Model Parameter Randomisation Test, 'explanation_scores' "
             "must be of type dict."
         )
         layer_length = len(
-            self.similarity_scores[list(self.similarity_scores.keys())[0]]
+            self.explanation_scores[list(self.explanation_scores.keys())[0]]
         )
         results: Dict[int, list] = {sample: [] for sample in range(layer_length)}
 
         for sample in results:
-            for layer in self.similarity_scores:
+            for layer in self.explanation_scores:
                 if layer == "orig":
                     continue
-                results[sample].append(float(self.similarity_scores[layer][sample]))
+                results[sample].append(float(self.explanation_scores[layer][sample]))
             results[sample] = np.mean(results[sample])
 
         corr_coeffs = list(results.values())
@@ -493,12 +713,46 @@
     def recompute_last_correlation_per_sample(
         self,
     ) -> Union[List[List[Any]], Dict[int, List[Any]]]:
 
-        assert isinstance(self.similarity_scores, dict), (
+        assert isinstance(self.explanation_scores, dict), (
             "To compute the last correlation coefficient per sample for "
-            "enhanced Model Parameter Randomisation Test, 'similarity_scores' "
+            "enhanced Model Parameter Randomisation Test, 'explanation_scores' "
             "must be of type dict."
         )
-        # Return the correlation coefficient of the fully randomised model.
-        corr_coeffs = list(self.similarity_scores.values())[-1]
+        corr_coeffs = list(self.explanation_scores.values())[-1]
 
         return corr_coeffs
+
+    def find_n_bins(
+        self,
+        a_batch: np.array,
+        n_bins_default: int = 100,
+        min_n_bins: int = 10,
+        max_n_bins: int = 200,
+        debug: bool = True,
+    ) -> None:
+        """Derive 'n_bins' for discrete entropy from a_batch, using the configured rule."""
+
+        if self.normalise:
+            a_batch = self.normalise_func(a_batch, **self.normalise_func_kwargs)
+        if self.abs:
+            a_batch = np.abs(a_batch)
+
+        rule_name = self.complexity_func_kwargs.get("rule", None)
+        rule = AVAILABLE_N_BINS_ALGORITHMS.get(rule_name)
+
+        if debug:
+            print(f"\tMax and min value of a_batch=({a_batch.min()}, {a_batch.max()})")
+
+        if not rule:
+            self.complexity_func_kwargs["n_bins"] = n_bins_default
+            if debug:
+                print(f"\tNo rule found, 'n_bins' set to {n_bins_default}.")
+            return None
+
+        n_bins = rule(a_batch=a_batch)
+        n_bins = max(min(n_bins, max_n_bins), min_n_bins)
+        self.complexity_func_kwargs["n_bins"] = n_bins
+
+        if debug:
+            print(
+                f"\tRule '{rule_name}' -> n_bins={n_bins} but with min={min_n_bins} and max={max_n_bins}, 'n_bins' set to {self.complexity_func_kwargs['n_bins']}."
+ ) diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/model_parameter_randomisation.py index 3658b0dbc..ac814b269 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/model_parameter_randomisation.py @@ -326,18 +326,19 @@ def generate_y_batches(): ): pbar.desc = layer_name - similarity_scores = [] - # Skip layers if computing delta. if self.skip_layers and (l_ix + 1) < n_layers: continue if l_ix == 0: - # Generate explanations on modified model in batches + + # Generate explanations on modified model in batches. a_original_generator = self.generate_explanations( model.get_model(), x_full_dataset, y_full_dataset, batch_size ) + # Compute the similarity of explanations of the original model. + self.evaluation_scores["original"] = [] for a_batch, a_batch_original in zip( generate_y_batches(), a_original_generator ): @@ -352,17 +353,18 @@ def generate_y_batches(): a=a_instance, a_perturbed=a_instance_original, ) - similarity_scores.append(score) + # Save similarity scores in a result dictionary. + self.evaluation_scores["original"].append(score) pbar.update(1) - # Save similarity scores in a result dictionary. - self.evaluation_scores["original"] = similarity_scores + self.evaluation_scores[layer_name] = [] - # Generate explanations on modified model in batches + # Generate explanations on modified model in batches. a_perturbed_generator = self.generate_explanations( random_layer_model, x_full_dataset, y_full_dataset, batch_size ) + # Compute the similarity of explanations of the perturbed model. for a_batch, a_batch_perturbed in zip( generate_y_batches(), a_perturbed_generator ): @@ -375,12 +377,9 @@ def generate_y_batches(): a=a_instance, a_perturbed=a_instance_perturbed, ) - similarity_scores.append(score) + self.evaluation_scores[layer_name].append(score) pbar.update(1) - # Save similarity scores in a result dictionary. - self.evaluation_scores[layer_name] = similarity_scores - if self.return_average_correlation: self.evaluation_scores = self.recompute_average_correlation_per_sample() @@ -394,6 +393,7 @@ def generate_y_batches(): ) self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)] + # Return all_evaluation_scores according to Quantus. self.all_evaluation_scores.append(self.evaluation_scores) return self.evaluation_scores diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py index bb858365a..bef8ea483 100644 --- a/tests/metrics/test_randomisation_metrics.py +++ b/tests/metrics/test_randomisation_metrics.py @@ -5,9 +5,14 @@ import numpy as np from quantus.functions.explanation_func import explain +from quantus.functions import complexity_func, n_bins_func from quantus.functions.similarity_func import correlation_spearman, correlation_pearson from quantus.helpers.model.model_interface import ModelInterface -from quantus.metrics.randomisation import ModelParameterRandomisation, RandomLogit +from quantus.metrics.randomisation import ( + ModelParameterRandomisation, + RandomLogit, + EfficientModelParameterRandomisation, +) def explain_func_stub(*args, **kwargs): @@ -592,9 +597,350 @@ def test_random_logit( a_batch=a_batch, **call_params, ) - for s in scores: - if not (expected["min"] <= s <= expected["max"]): - print("!!!!", s) assert all( expected["min"] <= s <= expected["max"] for s in scores ), f"Test failed with scores {scores}." 
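For orientation, a minimal eMPRT invocation in the style of the parametrised cases below might look as follows (model, x_batch and y_batch come from the test fixtures; "Saliency" stands in for any supported explanation method):

    metric = EfficientModelParameterRandomisation(
        complexity_func=complexity_func.discrete_entropy,
        similarity_func=correlation_spearman,
        layer_order="bottom_up",
        normalise=True,
        disable_warnings=True,
    )
    scores = metric(
        model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        explain_func=explain,
        explain_func_kwargs={"method": "Saliency"},
    )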
+ + +@pytest.mark.emprt +@pytest.mark.parametrize( + "model,data,params,expected", + [ + ( + lazy_fixture("load_1d_3ch_conv_model"), + lazy_fixture("almost_uniform_1d_no_abatch"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "disable_warnings": False, + "display_progressbar": False, + "aggregate_func": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "disable_warnings": False, + "display_progressbar": False, + "compute_extra_scores": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_1d_3ch_conv_model"), + lazy_fixture("almost_uniform_1d_no_abatch"), + { + "init": { + "complexity_func": complexity_func.discrete_entropy, + "complexity_func_kwargs": {"n_bins": 10}, + "layer_order": "bottom_up", + "similarity_func": correlation_pearson, + "normalise": True, + "disable_warnings": True, + "display_progressbar": False, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "complexity_func_kwargs": {"rule": "Scotts"}, + "similarity_func": correlation_pearson, + "normalise": True, + "disable_warnings": True, + "display_progressbar": False, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Gradient", + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model_tf"), + lazy_fixture("load_mnist_images_tf"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "complexity_func_kwargs": {"rule": "Square Root"}, + "similarity_func": correlation_spearman, + "normalise": True, + "disable_warnings": True, + "display_progressbar": False, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "VanillaGradients", + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model_tf"), + lazy_fixture("load_mnist_images_tf"), + { + "a_batch_generate": False, + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_pearson, + "normalise": True, + "disable_warnings": True, + "display_progressbar": False, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Gradient", + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_1d_3ch_conv_model"), + lazy_fixture("almost_uniform_1d_no_abatch"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "disable_warnings": True, + "display_progressbar": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + 
}, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "disable_warnings": True, + "display_progressbar": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("titanic_model_torch"), + lazy_fixture("titanic_dataset"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "IntegratedGradients", + "reduce_axes": (), + }, + }, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("titanic_model_tf"), + lazy_fixture("titanic_dataset"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "skip_layers": False, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "compute_extra_scores": False, + "skip_layers": False, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "compute_extra_scores": False, + "skip_layers": True, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "complexity_func": complexity_func.discrete_entropy, + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "compute_extra_scores": False, + "skip_layers": True, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1000000000, "max": 1000000000}, + ), + ], +) +def test_efficient_model_parameter_randomisation( + model: ModelInterface, + data: np.ndarray, + params: dict, + expected: Union[float, dict, bool], +): + x_batch, y_batch = ( + data["x_batch"], + data["y_batch"], + ) + + init_params = params.get("init", {}) + call_params = params.get("call", {}) + + if params.get("a_batch_generate", True): + explain = call_params["explain_func"] + explain_func_kwargs = 
call_params.get("explain_func_kwargs", {}) + a_batch = explain( + model=model, + inputs=x_batch, + targets=y_batch, + **explain_func_kwargs, + ) + elif "a_batch" in data: + a_batch = data["a_batch"] + else: + a_batch = None + + if "exception" in expected: + with pytest.raises(expected["exception"]): + scores_layers = EfficientModelParameterRandomisation(**init_params)( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + **call_params, + ) + return + + scores = EfficientModelParameterRandomisation(**init_params)( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + **call_params, + ) + + out_of_range_scores = [ + s for s in scores if not (expected["min"] <= s <= expected["max"]) + ] + assert ( + not out_of_range_scores + ), f"Test failed. Out of range scores: {out_of_range_scores}" From c91bb8d10bf858408f6107c13564615a3ce3e9f7 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Thu, 23 Nov 2023 12:14:07 +0100 Subject: [PATCH 05/11] test fixed mprt, emprt and smprt --- CONTRIBUTING.md | 2 +- quantus/evaluation.py | 5 +- quantus/functions/explanation_func.py | 276 +++--- quantus/helpers/constants.py | 5 +- quantus/helpers/model/model_interface.py | 30 +- quantus/helpers/utils.py | 37 +- quantus/metrics/axiomatic/completeness.py | 2 +- quantus/metrics/axiomatic/input_invariance.py | 2 +- quantus/metrics/axiomatic/non_sensitivity.py | 2 +- quantus/metrics/base.py | 2 +- quantus/metrics/complexity/complexity.py | 2 +- .../complexity/effective_complexity.py | 2 +- quantus/metrics/complexity/sparseness.py | 2 +- .../faithfulness/faithfulness_correlation.py | 2 +- .../faithfulness/faithfulness_estimate.py | 2 +- quantus/metrics/faithfulness/infidelity.py | 2 +- quantus/metrics/faithfulness/irof.py | 2 +- quantus/metrics/faithfulness/monotonicity.py | 2 +- .../faithfulness/monotonicity_correlation.py | 4 +- .../metrics/faithfulness/pixel_flipping.py | 2 +- .../faithfulness/region_perturbation.py | 2 +- quantus/metrics/faithfulness/road.py | 2 +- quantus/metrics/faithfulness/selectivity.py | 2 +- quantus/metrics/faithfulness/sensitivity_n.py | 2 +- quantus/metrics/faithfulness/sufficiency.py | 2 +- .../localisation/attribution_localisation.py | 2 +- quantus/metrics/localisation/auc.py | 2 +- quantus/metrics/localisation/focus.py | 2 +- quantus/metrics/localisation/pointing_game.py | 2 +- .../localisation/relevance_mass_accuracy.py | 2 +- .../localisation/relevance_rank_accuracy.py | 2 +- .../localisation/top_k_intersection.py | 2 +- quantus/metrics/randomisation/__init__.py | 9 +- ...ter_randomisation.py => efficient_mprt.py} | 70 +- quantus/metrics/randomisation/emprt.py | 758 ----------------- ...del_parameter_randomisation.py => mprt.py} | 63 +- quantus/metrics/randomisation/random_logit.py | 2 +- quantus/metrics/randomisation/smooth_mprt.py | 794 ++++++++++++++++++ quantus/metrics/robustness/avg_sensitivity.py | 2 +- quantus/metrics/robustness/consistency.py | 2 +- quantus/metrics/robustness/continuity.py | 2 +- .../robustness/local_lipschitz_estimate.py | 2 +- quantus/metrics/robustness/max_sensitivity.py | 2 +- tests/metrics/test_randomisation_metrics.py | 493 ++++++++++- tox.ini | 4 +- 45 files changed, 1581 insertions(+), 1031 deletions(-) rename quantus/metrics/randomisation/{efficient_model_parameter_randomisation.py => efficient_mprt.py} (92%) delete mode 100644 quantus/metrics/randomisation/emprt.py rename quantus/metrics/randomisation/{model_parameter_randomisation.py => mprt.py} (91%) create mode 100644 
quantus/metrics/randomisation/smooth_mprt.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c77983dc8..687fc81c4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -89,7 +89,7 @@ It is possible to limit the scope of testing to specific sections of the codebas Faithfulness metrics using python3.9 (make sure the python versions match in your environment): ```bash -python3 -m tox run -e py39 -- -m faithfulness -s +python3 -m tox run -e py39 -- -m smprt -s ``` For a complete overview of the possible testing scopes, please refer to `pytest.ini`. diff --git a/quantus/evaluation.py b/quantus/evaluation.py index 75d628ee3..102c0980e 100644 --- a/quantus/evaluation.py +++ b/quantus/evaluation.py @@ -81,7 +81,7 @@ def evaluate( return None if call_kwargs is None: - call_kwargs = {'call_kwargs_empty': {}} + call_kwargs = {"call_kwargs_empty": {}} elif not isinstance(call_kwargs, Dict): raise TypeError("xai_methods type is not Dict[str, Dict].") @@ -99,6 +99,9 @@ def evaluate( explain_funcs[method] = value explain_func = value + assert ( + explain_func_kwargs is not None + ), "Pass explain_func_kwargs as a dictionary." # Asserts. asserts.assert_explain_func(explain_func=explain_func) diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py index 178e24a4f..74eb04819 100644 --- a/quantus/functions/explanation_func.py +++ b/quantus/functions/explanation_func.py @@ -275,149 +275,160 @@ def generate_tf_explanation( explanation: np.ndarray = np.zeros_like(inputs) - if method in constants.DEPRECATED_XAI_METHODS_TF: - warnings.warn( - f"Explanaiton method string {method} is deprecated. Use " - f"{constants.DEPRECATED_XAI_METHODS_TF[method]} instead.\n", - category=UserWarning, - ) - method = constants.DEPRECATED_XAI_METHODS_TF[method] - - if method == "VanillaGradients": - explainer = tf_explain.core.vanilla_gradients.VanillaGradients() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, + try: + if method in constants.DEPRECATED_XAI_METHODS_TF: + warnings.warn( + f"Explanation method string {method} is deprecated. 
Use " + f"{constants.DEPRECATED_XAI_METHODS_TF[method]} instead.\n", + category=UserWarning, ) - / 255 - ) + method = constants.DEPRECATED_XAI_METHODS_TF[method] - elif method == "IntegratedGradients": - n_steps = kwargs.get("n_steps", 10) - explainer = tf_explain.core.integrated_gradients.IntegratedGradients() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, n_steps=n_steps, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, + if method == "VanillaGradients": + explainer = tf_explain.core.vanilla_gradients.VanillaGradients() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), model, y, **xai_lib_kwargs + ), + inputs, + targets, + ) + ), + dtype=float, + ) + / 255 ) - / 255 - ) - elif method == "GradientsInput": - explainer = tf_explain.core.gradients_inputs.GradientsInputs() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, + elif method == "IntegratedGradients": + n_steps = kwargs.get("n_steps", 10) + explainer = tf_explain.core.integrated_gradients.IntegratedGradients() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), model, y, n_steps=n_steps, **xai_lib_kwargs + ), + inputs, + targets, + ) + ), + dtype=float, + ) + / 255 ) - / 255 - ) - elif method == "OcclusionSensitivity": - patch_size = kwargs.get("window", (1, *([4] * (inputs.ndim - 2))))[-1] - reduce_axes = kwargs.get("reduce_axes", (-1,)) - keepdims = kwargs.get("keepdims", False) - keep_dim = False - explainer = tf_explain.core.occlusion_sensitivity.OcclusionSensitivity() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), - model, - y, - patch_size=patch_size, - **xai_lib_kwargs, - ), - inputs, - targets, - ) - ), - dtype=float, + elif method == "GradientsInput": + explainer = tf_explain.core.gradients_inputs.GradientsInputs() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), model, y, **xai_lib_kwargs + ), + inputs, + targets, + ) + ), + dtype=float, + ) + / 255 ) - / 255 - ) - elif method == "GradCAM": - reduce_axes = kwargs.get("reduce_axes", (-1,)) - keepdims = kwargs.get("keepdims", False) - keep_dim = False - if "gc_layer" in kwargs: - xai_lib_kwargs["layer_name"] = kwargs["gc_layer"] - - explainer = tf_explain.core.grad_cam.GradCAM() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, + elif method == "OcclusionSensitivity": + patch_size = kwargs.get("window", (1, *([4] * (inputs.ndim - 2))))[-1] + reduce_axes = kwargs.get("reduce_axes", (-1,)) + keepdims = kwargs.get("keepdims", False) + keep_dim = False + explainer = tf_explain.core.occlusion_sensitivity.OcclusionSensitivity() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), + model, + y, + patch_size=patch_size, + **xai_lib_kwargs, + ), + inputs, + targets, + ) + ), + dtype=float, + ) + / 255 ) - / 255 - ) - elif method == "SmoothGrad": - - num_samples = kwargs.get("num_samples", 5) - noise = kwargs.get("noise", 0.1) - explainer = tf_explain.core.smoothgrad.SmoothGrad() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), - model, - y, - num_samples=num_samples, - noise=noise, - **xai_lib_kwargs, - ), - 
inputs,
-                    targets,
-                )
-            ),
-            dtype=float,
+        elif method == "GradCAM":
+            reduce_axes = kwargs.get("reduce_axes", (-1,))
+            keepdims = kwargs.get("keepdims", False)
+            keep_dim = False
+            if "gc_layer" in kwargs:
+                xai_lib_kwargs["layer_name"] = kwargs["gc_layer"]
+
+            explainer = tf_explain.core.grad_cam.GradCAM()
+            explanation = (
+                np.array(
+                    list(
+                        map(
+                            lambda x, y: explainer.explain(
+                                ([x], None), model, y, **xai_lib_kwargs
+                            ),
+                            inputs,
+                            targets,
+                        )
+                    ),
+                    dtype=float,
+                )
+                / 255
             )
-            / 255
-        )
-    else:
-        raise KeyError(
-            f"Specify a XAI method that already has been implemented {constants.AVAILABLE_XAI_METHODS_TF}."
-        )
+        elif method == "SmoothGrad":
+
+            num_samples = kwargs.get("num_samples", 5)
+            noise = kwargs.get("noise", 0.1)
+            explainer = tf_explain.core.smoothgrad.SmoothGrad()
+            explanation = (
+                np.array(
+                    list(
+                        map(
+                            lambda x, y: explainer.explain(
+                                ([x], None),
+                                model,
+                                y,
+                                num_samples=num_samples,
+                                noise=noise,
+                                **xai_lib_kwargs,
+                            ),
+                            inputs,
+                            targets,
+                        )
+                    ),
+                    dtype=float,
+                )
+                / 255
+            )
+
+        else:
+            raise KeyError(
+                f"Specify an XAI method that has already been implemented: {constants.AVAILABLE_XAI_METHODS_TF}."
+            )
+
+    except ValueError as e:
+        if "must be at least three-dimensional" in str(e):
+            # Fall back to random uniform explanations, as announced in the warning.
+            warnings.warn(
+                "Input data must be at least three-dimensional for tf-explain methods. "
+                "Returning explanations with random uniform values of the same shape as inputs.\n",
+                UserWarning,
+            )
+            explanation = np.random.uniform(low=0.0, high=1.0, size=inputs.shape)
+        else:
+            raise
 
     assert 0 not in reduce_axes, (
         "Reduction over batch_axis is not available, please do not "
@@ -430,7 +441,8 @@ def generate_tf_explanation(
 
     reduce_axes = {"axis": tuple(reduce_axes), "keepdims": keepdims}
 
-    # Prevent attribution summation for 2D-data. Recreate np.sum behavior when passing reduce_axes=(), i.e. no change.
+    # Prevent attribution summation for 2D-data.
+    # Recreate np.sum behavior when passing reduce_axes=(), i.e. no change.
     if (len(tuple(reduce_axes)) == 0) | (explanation.ndim < 3):
         return explanation
 
diff --git a/quantus/helpers/constants.py b/quantus/helpers/constants.py
index a60f02d4d..5c1d68b3f 100644
--- a/quantus/helpers/constants.py
+++ b/quantus/helpers/constants.py
@@ -62,8 +62,9 @@
         "Effective Complexity": EffectiveComplexity,
     },
     "Randomisation": {
-        "Model Parameter Randomisation Test": ModelParameterRandomisation,
-        "Efficient Model Parameter Randomisation Test": EfficientModelParameterRandomisation,
+        "MPRT": MPRT,
+        "Smooth MPRT": SmoothMPRT,
+        "Efficient MPRT": EfficientMPRT,
         "Random Logit": RandomLogit,
     },
     "Axiomatic": {
diff --git a/quantus/helpers/model/model_interface.py b/quantus/helpers/model/model_interface.py
index 68513d89b..9d7e924e9 100644
--- a/quantus/helpers/model/model_interface.py
+++ b/quantus/helpers/model/model_interface.py
@@ -1,16 +1,22 @@
 """This model implements the basics for the ModelInterface class."""
-
 # This file is part of Quantus.
 # Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 # Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
 # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see .
 # Quantus project URL: .
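The hunk below adds lightweight framework detection to ModelInterface: probe for the package with importlib.util.find_spec, import lazily, then isinstance-check the wrapped model. Read in isolation, the pattern could be sketched like this (illustrative only, not part of the patch; the free function `ml_framework_name` is made up here for demonstration):

```python
# Standalone sketch of the detection pattern; assumes at least one of
# torch / tensorflow is installed in the environment.
import warnings
from importlib import util

if util.find_spec("tensorflow"):
    import tensorflow as tf
if util.find_spec("torch"):
    import torch


def ml_framework_name(model) -> str:
    """Return 'torch', 'tensorflow' or 'unknown' for a given model object."""
    if util.find_spec("torch") and isinstance(model, torch.nn.Module):
        return "torch"
    if util.find_spec("tensorflow") and isinstance(model, tf.keras.Model):
        return "tensorflow"
    warnings.warn("Cannot identify ML framework of the given model.")
    return "unknown"
```

Guarding the imports behind find_spec keeps the module importable when only one of the two frameworks is installed.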
+import warnings
+from importlib import util
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional, Tuple, List, Union, Generator, TypeVar, Generic
 
 import numpy as np
 
+if util.find_spec("tensorflow"):
+    import tensorflow as tf
+if util.find_spec("torch"):
+    import torch
+
 M = TypeVar("M")
 
 
@@ -20,7 +26,7 @@ class ModelInterface(ABC, Generic[M]):
     def __init__(
         self,
         model: M,
-        channel_first: Optional[bool] = None,
+        channel_first: Optional[bool] = True,
         softmax: bool = False,
         model_predict_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -191,3 +197,23 @@ def random_layer_generator_length(self) -> int:
         Number of layers in model, which can be randomised.
         """
         raise NotImplementedError
+
+    @property
+    def get_ml_framework_name(self) -> str:
+        """
+        Identify the framework of the underlying model (PyTorch or TensorFlow).
+
+        Returns
+        -------
+        str
+            A string indicating the framework ('torch', 'tensorflow', or 'unknown').
+        """
+        if util.find_spec("torch"):
+            if isinstance(self.model, torch.nn.Module):
+                return "torch"
+        if util.find_spec("tensorflow"):
+            if isinstance(self.model, tf.keras.Model):
+                return "tensorflow"
+        warnings.warn("Cannot identify ML framework of the given model.")
+        return "unknown"
diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py
index 2f2c9458d..784a9aa8d 100644
--- a/quantus/helpers/utils.py
+++ b/quantus/helpers/utils.py
@@ -243,17 +243,23 @@ def infer_channel_first(x: np.array) -> bool:
 
     Returns
     -------
-    For 1d input:
+    For 1D input:
+        True for input shape (nr_batch, nr_features).
+
+    For 2D input:
         True if input shape is (nr_batch, nr_channels, sequence_length).
         False if input shape is (nr_batch, sequence_length, nr_channels).
         An error is raised if the two last dimensions are equal.
 
-    For 2d input:
+    For 3D input:
         True if input shape is (nr_batch, nr_channels, img_width, img_height).
         False if input shape is (nr_batch, img_width, img_height, nr_channels).
         An error is raised if the three last dimensions are equal.
     """
-    err_msg = "Ambiguous input shape. Cannot infer channel-first/channel-last order. Try setting the `channel_first` argument"
+    err_msg = (
+        "Ambiguous input shape. Cannot infer channel-first/channel-last order. "
+        "Try setting the `channel_first` argument."
+    )
 
     if len(np.shape(x)) == 2:
         return True
@@ -275,11 +281,11 @@ def infer_channel_first(x: np.array) -> bool:
 
     else:
         raise ValueError(
-            "Only batched 1d and 2d multi-channel input dimensions supported."
+            "Only batched 2D and 3D multi-channel input dimensions supported."
         )
 
 
-def make_channel_first(x: np.array, channel_first=False):
+def make_channel_first(x: np.array, channel_first: bool = False):
     """
     Reshape batch to channel first.
@@ -295,18 +301,20 @@ def make_channel_first(x: np.array, channel_first: bool = False):
     """
     if channel_first:
         return x
-
-    if len(np.shape(x)) == 4:
-        return np.moveaxis(x, -1, -3)
+    if len(np.shape(x)) == 2:
+        return x
     elif len(np.shape(x)) == 3:
         return np.moveaxis(x, -1, -2)
+    elif len(np.shape(x)) == 4:
+        return np.moveaxis(x, -1, -3)
     else:
         raise ValueError(
-            "Only batched 1d and 2d multi-channel input dimensions supported."
+            "Only batched 2D and 3D multi-channel input dimensions supported."
         )
 
 
-def make_channel_last(x: np.array, channel_first=True):
+def make_channel_last(x: np.array, channel_first: bool = True):
     """
     Reshape batch to channel last.
@@ -322,14 +330,15 @@ def make_channel_last(x: np.array, channel_first=True): """ if not channel_first: return x - - if len(np.shape(x)) == 4: - return np.moveaxis(x, -3, -1) + if len(np.shape(x)) == 2: + return x elif len(np.shape(x)) == 3: return np.moveaxis(x, -2, -1) + elif len(np.shape(x)) == 4: + return np.moveaxis(x, -3, -1) else: raise ValueError( - "Only batched 1d and 2d multi-channel input dimensions supported." + "Only batched 2D and 3D multi-channel input dimensions supported." ) diff --git a/quantus/metrics/axiomatic/completeness.py b/quantus/metrics/axiomatic/completeness.py index b831d42ed..02877c56a 100644 --- a/quantus/metrics/axiomatic/completeness.py +++ b/quantus/metrics/axiomatic/completeness.py @@ -244,7 +244,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/axiomatic/input_invariance.py b/quantus/metrics/axiomatic/input_invariance.py index b1d1225a1..5310e6fda 100644 --- a/quantus/metrics/axiomatic/input_invariance.py +++ b/quantus/metrics/axiomatic/input_invariance.py @@ -224,7 +224,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/axiomatic/non_sensitivity.py b/quantus/metrics/axiomatic/non_sensitivity.py index 63877eda5..2c38eb523 100644 --- a/quantus/metrics/axiomatic/non_sensitivity.py +++ b/quantus/metrics/axiomatic/non_sensitivity.py @@ -244,7 +244,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index d2db4f494..999520f1e 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -256,7 +256,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ # Run deprecation warnings. warn.deprecation_warnings(kwargs) diff --git a/quantus/metrics/complexity/complexity.py b/quantus/metrics/complexity/complexity.py index f858e1257..19be155be 100644 --- a/quantus/metrics/complexity/complexity.py +++ b/quantus/metrics/complexity/complexity.py @@ -209,7 +209,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. 
>> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/complexity/effective_complexity.py b/quantus/metrics/complexity/effective_complexity.py index 3d8157423..8f2248703 100644 --- a/quantus/metrics/complexity/effective_complexity.py +++ b/quantus/metrics/complexity/effective_complexity.py @@ -212,7 +212,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/complexity/sparseness.py b/quantus/metrics/complexity/sparseness.py index ea2d43ade..9f7379bba 100644 --- a/quantus/metrics/complexity/sparseness.py +++ b/quantus/metrics/complexity/sparseness.py @@ -214,7 +214,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index b0ffe5a78..6e8e405d5 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -257,7 +257,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 282579d56..dad5fdaaa 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -240,7 +240,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 37717266b..2fe75df76 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -264,7 +264,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. 
>> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 61df3bf52..2b0650cdd 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -240,7 +240,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 76c13d9e2..0f3ca879a 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -236,7 +236,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index a1d48aa25..7efddf17e 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -257,7 +257,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, @@ -307,7 +307,7 @@ def evaluate_instance( y_pred = float(model.predict(x_input)[:, y]) inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred) - inv_pred = inv_pred**2 + inv_pred = inv_pred ** 2 # Reshape attributions. a = a.flatten() diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index a5d173aea..a081b6ba7 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -236,7 +236,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 560283d6c..7a7370e25 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -257,7 +257,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. 
>> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 7e4da4106..5e25f0dda 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -237,7 +237,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 0281679bd..2058d05e4 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -247,7 +247,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 48ec7430a..6abafee83 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -256,7 +256,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index c931bb82e..4d2fb9ca9 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -230,7 +230,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/localisation/attribution_localisation.py b/quantus/metrics/localisation/attribution_localisation.py index 3cece8ad6..0cdad7374 100644 --- a/quantus/metrics/localisation/attribution_localisation.py +++ b/quantus/metrics/localisation/attribution_localisation.py @@ -221,7 +221,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. 
>> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/localisation/auc.py b/quantus/metrics/localisation/auc.py index 6ca02f822..803d6d009 100644 --- a/quantus/metrics/localisation/auc.py +++ b/quantus/metrics/localisation/auc.py @@ -201,7 +201,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/localisation/focus.py b/quantus/metrics/localisation/focus.py index f9522efa2..def438ee5 100644 --- a/quantus/metrics/localisation/focus.py +++ b/quantus/metrics/localisation/focus.py @@ -249,7 +249,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( diff --git a/quantus/metrics/localisation/pointing_game.py b/quantus/metrics/localisation/pointing_game.py index 328277155..804d82b15 100644 --- a/quantus/metrics/localisation/pointing_game.py +++ b/quantus/metrics/localisation/pointing_game.py @@ -214,7 +214,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/localisation/relevance_mass_accuracy.py b/quantus/metrics/localisation/relevance_mass_accuracy.py index 251d48e6a..808a77489 100644 --- a/quantus/metrics/localisation/relevance_mass_accuracy.py +++ b/quantus/metrics/localisation/relevance_mass_accuracy.py @@ -208,7 +208,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/localisation/relevance_rank_accuracy.py b/quantus/metrics/localisation/relevance_rank_accuracy.py index 9bd80d6ed..f09e5d2b9 100644 --- a/quantus/metrics/localisation/relevance_rank_accuracy.py +++ b/quantus/metrics/localisation/relevance_rank_accuracy.py @@ -210,7 +210,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. 
>> metric = Metric(abs=True, normalise=False)
-        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency}
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
     """
         return super().__call__(
             model=model,
diff --git a/quantus/metrics/localisation/top_k_intersection.py b/quantus/metrics/localisation/top_k_intersection.py
index 1d4095443..df32043ba 100644
--- a/quantus/metrics/localisation/top_k_intersection.py
+++ b/quantus/metrics/localisation/top_k_intersection.py
@@ -219,7 +219,7 @@ def __call__(
 
         # Initialise the metric and evaluate explanations by calling the metric instance.
         >> metric = Metric(abs=True, normalise=False)
-        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency}
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
     """
         return super().__call__(
             model=model,
diff --git a/quantus/metrics/randomisation/__init__.py b/quantus/metrics/randomisation/__init__.py
index 690262860..b3ffa7d35 100644
--- a/quantus/metrics/randomisation/__init__.py
+++ b/quantus/metrics/randomisation/__init__.py
@@ -4,10 +4,7 @@
 # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see .
 # Quantus project URL: .
 
-from quantus.metrics.randomisation.model_parameter_randomisation import (
-    ModelParameterRandomisation,
-)
-from quantus.metrics.randomisation.efficient_model_parameter_randomisation import (
-    EfficientModelParameterRandomisation,
-)
+from quantus.metrics.randomisation.mprt import MPRT
+from quantus.metrics.randomisation.efficient_mprt import EfficientMPRT
+from quantus.metrics.randomisation.smooth_mprt import SmoothMPRT
 from quantus.metrics.randomisation.random_logit import RandomLogit
diff --git a/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py b/quantus/metrics/randomisation/efficient_mprt.py
similarity index 92%
rename from quantus/metrics/randomisation/efficient_model_parameter_randomisation.py
rename to quantus/metrics/randomisation/efficient_mprt.py
index e2967e2f7..9b6cc99ae 100644
--- a/quantus/metrics/randomisation/efficient_model_parameter_randomisation.py
+++ b/quantus/metrics/randomisation/efficient_mprt.py
@@ -54,11 +54,11 @@
 
 @final
-class EfficientModelParameterRandomisation(Metric):
+class EfficientMPRT(Metric):
     """
-    Implementation of the Efficient Model Parameter Randomization Method by Hedström et. al., 2023.
+    Implementation of the Efficient MPRT by Hedström et al., 2023.
 
-    The Efficient Model Parameter Randomisation measures replaces the layer-by-layer pairwise comparison
+    The Efficient MPRT replaces the layer-by-layer pairwise comparison
     between e and ˆe of MPRT by instead computing the relative rise in explanation complexity using only
     two model states, i.e., the original- and fully randomised model version
@@ -103,6 +103,10 @@ def __init__(
         """
         Parameters
         ----------
+        complexity_func: callable
+            A callable that computes the complexity of an explanation.
+        complexity_func_kwargs: dict, optional
+            Keyword arguments to be passed to complexity_func on call.
         similarity_func: callable
             Similarity function applied to compare input and perturbed input, default=correlation_spearman.
         layer_order: string
             Indicated whether the model is randomized cascadingly or independently.
             Set order=top_down for cascading randomization, set order=independent for independent randomization,
             default="independent".
         seed: integer
             Seed used for the random generator, default=42.
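To make the renamed API concrete before the parameter documentation continues below, a minimal call could look as follows (a hypothetical sketch, not taken from the patch; it assumes a trained torch `model` plus numpy arrays `x_batch`/`y_batch` exist, and mirrors the kwargs exercised in the tests above):

```python
# Hypothetical usage sketch of the renamed metric; `model`, `x_batch` and
# `y_batch` are assumed to exist in scope.
import quantus

metric = quantus.EfficientMPRT(
    layer_order="bottom_up",  # randomise weights from the lowest layer upwards
    normalise=True,
    disable_warnings=True,
)
scores = metric(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    explain_func=quantus.explain,
    explain_func_kwargs={"method": "Saliency"},
)
```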
-        return_average_correlation: boolean
-            Indicates whether to return one float per sample, computing the average
-            correlation coefficient across the layers for a given sample.
-        return_last_correlation: boolean
-            Indicates whether to return one float per sample, computing the explanation
-            correlation coefficient for the full model randomisation (not layer-wise) of a sample.
+        compute_extra_scores: boolean
+            Indicates if extra scores should be computed (and stored in a metric
+            attribute (dict) called scores_extra).
         skip_layers: boolean
             Indicates if explanation similarity should be computed only once; between the original and fully randomised model, instead of in a layer-by-layer basis.
@@ -142,7 +143,6 @@ def __init__(
         kwargs: optional
             Keyword arguments.
         """
-
         super().__init__(
             abs=abs,
             normalise=normalise,
@@ -191,13 +191,13 @@ def __init__(
             warn.warn_parameterisation(
                 metric_name=self.__class__.__name__,
                 sensitive_params=(
-                    "similarity metric 'similarity_func' and the order of "
-                    "the layer randomisation 'layer_order'"
+                    "the order of the layer randomisation 'layer_order' (we recommend "
+                    "bottom-up randomisation and advise against top-down randomisation)"
                 ),
                 citation=(
-                    "Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I., Hardt, M., and Kim, B. "
-                    "'Sanity Checks for Saliency Maps.' arXiv preprint,"
-                    " arXiv:1810.073292v3 (2018)"
+                    'Hedström, Anna, et al. "Sanity Checks Revisited: An Exploration to Repair'
+                    ' the Model Parameter Randomisation Test." XAI in Action: Past, Present, '
+                    "and Future Applications. 2023."
                 ),
             )
@@ -291,7 +291,7 @@ def __call__(
 
         # Initialise the metric and evaluate explanations by calling the metric instance.
         >> metric = Metric(abs=True, normalise=False)
-        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency}
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
         """
 
         # Run deprecation warnings.
@@ -348,10 +348,6 @@ def __call__(
             debug=self.complexity_func_kwargs.get("debug", False),
         )
 
-        def generate_y_batches():
-            for batch in gen_batches(len(a_full_dataset), batch_size):
-                yield a_full_dataset[batch.start : batch.stop]
-
         self.explanation_scores_by_layer = {}
         self.model_scores_by_layer = {}
 
@@ -366,7 +362,7 @@ def __call__(
                 continue
 
             if l_ix == 0:
-                # Generate explanations on modified model in batches.
+                # Generate explanations on original model in batches.
                 a_original_generator = self.generate_explanations(
                     model.get_model(), x_full_dataset, y_full_dataset, batch_size
                 )
@@ -374,7 +370,7 @@ def __call__(
                 # Compute the complexity of explanations of the original model.
                 self.explanation_scores_by_layer["orig"] = []
                 for a_batch, a_batch_original in zip(
-                    generate_y_batches(), a_original_generator
+                    self.generate_a_batches(a_full_dataset), a_original_generator
                 ):
                     for a_instance, a_instance_original in zip(
                         a_batch, a_batch_original
@@ -396,7 +392,7 @@ def __call__(
                     score = entropy(a=y_pred, x=y_pred)
                     self.model_scores_by_layer["orig"].append(score)
 
-            # Generate explanations on modified model in batches.
+            # Generate explanations on perturbed model in batches.
             a_perturbed_generator = self.generate_explanations(
                 random_layer_model, x_full_dataset, y_full_dataset, batch_size
             )
@@ -404,7 +400,7 @@ def __call__(
             # Compute the complexity of explanations of the perturbed model.
self.explanation_scores_by_layer[layer_name] = []
             for a_batch, a_batch_perturbed in zip(
-                generate_y_batches(), a_perturbed_generator
+                self.generate_a_batches(a_full_dataset), a_perturbed_generator
             ):
                 for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
                     score = self.evaluate_instance(
@@ -435,7 +431,6 @@ def __call__(
 
             # Predict and save complexity scores of the perturbed model outputs.
             self.model_scores_by_layer[layer_name] = []
-            print("!!", x_full_dataset.shape)
             y_preds = random_layer_model_wrapped.predict(x_full_dataset)
             for y_ix, y_pred in enumerate(y_preds):
                 score = entropy(a=y_pred, x=y_pred)
@@ -581,16 +576,37 @@ def generate_explanations(
         y_batch: np.ndarray,
         batch_size: int,
     ) -> Generator[np.ndarray, None, None]:
-        """Iterate over dataset in batches and generate explanations for complete dataset"""
+        """
+        Iterate over dataset in batches and generate explanations for the complete dataset.
 
+        Parameters
+        ----------
+        model: ModelInterface
+            A ModelInterface that is subject to explanation.
+        x_batch: np.ndarray
+            A np.ndarray which contains the input data that are explained.
+        y_batch: np.ndarray
+            A np.ndarray which contains the output labels that are explained.
+        batch_size: integer
+            Size of the batches in which explanations are generated.
+
+        Returns
+        -------
+        a_batch: Generator[np.ndarray, None, None]
+            Yields batches of explanations, ready to be evaluated.
+        """
         for i in gen_batches(len(x_batch), batch_size):
             x = x_batch[i.start : i.stop]
             y = y_batch[i.start : i.stop]
             a = self.explain_batch(model, x, y)
             yield a
 
+    def generate_a_batches(self, a_full_dataset):
+        """Yield the pre-computed explanations in batches of size self.batch_size."""
+        for batch in gen_batches(len(a_full_dataset), self.batch_size):
+            yield a_full_dataset[batch.start : batch.stop]
+
     def evaluate_batch(self, *args, **kwargs):
         raise RuntimeError(
-            "`evaluate_batch` must never be called for `Model Parameter Randomisation`."
+            "`evaluate_batch` must never be called for `Model Parameter Randomisation Test`."
         )
 
     def recompute_model_explanation_correlation_per_sample(
diff --git a/quantus/metrics/randomisation/emprt.py b/quantus/metrics/randomisation/emprt.py
deleted file mode 100644
index 6edbf0def..000000000
--- a/quantus/metrics/randomisation/emprt.py
+++ /dev/null
@@ -1,758 +0,0 @@
-"""This module contains the implementation of the enhanced Model Parameter Randomisation Test metric."""
-
-# This file is part of Quantus.
-# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
-# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
-# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see .
-# Quantus project URL: .
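Both generate_explanations and the new generate_a_batches helper above lean on sklearn.utils.gen_batches, which yields slice objects rather than arrays; hence the `batch.start : batch.stop` indexing. A toy sketch of that behaviour (illustrative data only, not Quantus code):

```python
# gen_batches yields slice objects covering 0..n in steps of batch_size.
import numpy as np
from sklearn.utils import gen_batches

a_full_dataset = np.arange(12).reshape(6, 2)  # six toy "explanations"
for batch in gen_batches(len(a_full_dataset), 4):
    # First slice(0, 4), then slice(4, 6).
    print(a_full_dataset[batch.start : batch.stop].shape)  # (4, 2), then (2, 2)
```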
- -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Union, - Collection, - Iterable, -) -import os -import numpy as np -from tqdm.auto import tqdm -import torch - -from quantus.helpers import asserts -from quantus.helpers import warn -from quantus.helpers import utils -from quantus.helpers.model.model_interface import ModelInterface -from quantus.functions.normalise_func import normalise_by_max - -# from quantus.functions import complexity_func -from quantus.metrics.base import Metric -from quantus.helpers.enums import ( - ModelType, - DataType, - ScoreDirection, - EvaluationCategory, -) - - -class eMPRT(Metric): - """ - Implementation of the NAME by AUTHOR et. al., 2023. - - INSERT DESC. - - References: - 1) INSERT SOURCE - - Attributes: - - _name: The name of the metric. - - _data_applicability: The data types that the metric implementation currently supports. - - _models: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. - """ - - name = "Enhanced Model Parameter Randomisation Test" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.HIGHER - evaluation_category = EvaluationCategory.RANDOMISATION - - def __init__( - self, - complexity_func: Optional[Callable] = None, - complexity_func_kwargs: Optional[dict] = None, - layer_order: str = "bottom_up", - nr_samples: Optional[int] = None, - seed: int = 42, - compute_delta: bool = True, - compute_rate_of_change: bool = True, - compute_delta_explanation_vs_model: bool = True, - compute_correlation: bool = True, - compute_last_complexity: bool = True, - return_delta_explanation_vs_model: bool = False, - return_fraction: bool = False, - return_rate_of_change: bool = True, - return_average_sample_score: bool = False, - return_correlation: bool = False, - return_last_complexity: bool = False, - return_delta_explanation: bool = False, - skip_layers: bool = False, - similarity_func: Optional[Callable] = None, - abs: bool = False, - normalise: bool = False, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, - normalise_func_kwargs: Optional[Dict[str, Any]] = None, - return_aggregate: bool = False, - aggregate_func: Callable = None, - default_plot_func: Optional[Callable] = None, - disable_warnings: bool = False, - display_progressbar: bool = False, - **kwargs, - ): - """ - Parameters - ---------- - similarity_func: callable - Similarity function applied to compare input and perturbed input, default=correlation_spearman. - layer_order: string - Indicated whether the model is randomized cascadingly or independently. - Set order=top_down for cascading randomization, set order=independent for independent randomization, - default="independent". - seed: integer - Seed used for the random generator, default=42. - return_average_sample_score: boolean - Indicates whether return one float per sample, representing the average - correlation coefficient across the layers for that sample. - abs: boolean - Indicates whether absolute operation is applied on the attribution, default=True. - normalise: boolean - Indicates whether normalise operation is applied on the attribution, default=True. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. 
- If normalise_func=None, the default value is used, default=normalise_by_max. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call, default={}. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed, default=False. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed, default=False. - kwargs: optional - Keyword arguments. - """ - if normalise_func is None: - normalise_func = normalise_by_max - - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - # Set seed for reproducibility. - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - torch.manual_seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - - # torch.backends.cudnn.benchmark = False - # torch.backends.cudnn.deterministic = True - # torch.backends.cudnn.enabled = False - - # Save metric-specific attributes. - if complexity_func is None: - complexity_func = discrete_entropy - - if complexity_func_kwargs is None: - complexity_func_kwargs = {} - - if similarity_func is None: - similarity_func = similarity_func.correlation_spearman - - self.complexity_func = complexity_func - self.complexity_func_kwargs = complexity_func_kwargs - self.similarity_func = similarity_func - self.layer_order = layer_order - self.nr_samples = nr_samples - self.compute_delta = compute_delta - self.compute_rate_of_change = compute_rate_of_change - self.compute_delta_explanation_vs_model = compute_delta_explanation_vs_model - self.compute_correlation = compute_correlation - self.compute_last_complexity = compute_last_complexity - self.return_average_sample_score = return_average_sample_score - self.return_fraction = return_fraction - self.return_rate_of_change = return_rate_of_change - self.return_delta_explanation_vs_model = return_delta_explanation_vs_model - self.return_correlation = return_correlation - self.return_last_complexity = return_last_complexity - self.return_delta_explanation = return_delta_explanation - self.skip_layers = skip_layers - - # Asserts and warnings. - assert ( - sum( - [ - self.return_fraction, - self.return_average_sample_score, - self.return_correlation, - self.return_last_complexity, - self.return_delta_explanation, - self.return_delta_explanation_vs_model, - self.return_rate_of_change, - ] - ) - == 1 - ), "Set one of the possible 'return' arguments to True." - - asserts.assert_layer_order(layer_order=self.layer_order) - if not self.disable_warnings: - warn.warn_parameterisation( - metric_name=self.__class__.__name__, - sensitive_params=( - "similarity metric 'similarity_func' and the order of " - "the layer randomisation 'layer_order'" - ), - citation=( - "Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I., Hardt, M., and Kim, B. " - "'Sanity Checks for Saliency Maps.' 
arXiv preprint," - " arXiv:1810.073292v3 (2018)" - ), - ) - - def __call__( - self, - model, - x_batch: np.array, - y_batch: np.array, - a_batch: Optional[np.ndarray] = None, - s_batch: Optional[np.ndarray] = None, - channel_first: Optional[bool] = None, - explain_func: Optional[Callable] = None, - explain_func_kwargs: Optional[Dict] = None, - model_predict_kwargs: Optional[Dict] = None, - softmax: Optional[bool] = False, - device: Optional[str] = None, - batch_size: int = 64, - custom_batch: Optional[Any] = None, - attributions_path: str = None, - **kwargs, - ) -> Union[List[float], float, Dict[str, List[float]], Collection[Any]]: - """ - This implementation represents the main logic of the metric and makes the class object callable. - It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), - output labels (y_batch) and a torch or tensorflow model (model). - - Calls general_preprocess() with all relevant arguments, calls - () on each instance, and saves results to evaluation_scores. - Calls custom_postprocess() afterwards. Finally returns evaluation_scores. - - The content of evaluation_scores will be appended to all_evaluation_scores (list) at the end of - the evaluation call. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - channel_first: boolean, optional - Indicates of the image dimensions are channel first, or channel last. - Inferred from the input shape if None. - explain_func: callable - Callable generating attributions. - explain_func_kwargs: dict, optional - Keyword arguments to be passed to explain_func on call. - model_predict_kwargs: dict, optional - Keyword arguments to be passed to the model's predict method. - softmax: boolean - Indicates whether to use softmax probabilities or logits in model prediction. - This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. - device: string - Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". - kwargs: optional - Keyword arguments. - - Returns - ------- - evaluation_scores: list - a list of Any with the evaluation scores of the concerned batch. - - Examples: - -------- - # Minimal imports. - >> import quantus - >> from quantus import LeNet - >> import torch - - # Enable GPU. - >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). - >> model = LeNet() - >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) - - # Load MNIST datasets and make loaders. - >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) - >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) - - # Load a batch of inputs and outputs to use for XAI evaluation. - >> x_batch, y_batch = iter(test_loader).next() - >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() - - # Generate Saliency attributions of the test set batch of the test set. 
- >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) - >> a_batch_saliency = a_batch_saliency.cpu().numpy() - - # Initialise the metric and evaluate explanations by calling the metric instance. - >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} - """ - - # Run deprecation warnings. - warn.deprecation_warnings(kwargs) - warn.check_kwargs(kwargs) - - data = self.general_preprocess( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - custom_batch=None, - channel_first=channel_first, - explain_func=explain_func, - explain_func_kwargs=explain_func_kwargs, - model_predict_kwargs=model_predict_kwargs, - softmax=softmax, - device=device, - ) - - # Get model and data. - model = data["model"] - x_batch = data["x_batch"] - y_batch = data["y_batch"] - a_batch = data["a_batch"] - - # Get number of iterations from number of layers. - n_layers = len(list(model.get_random_layer_generator(order=self.layer_order))) - model_iterator = tqdm( - model.get_random_layer_generator(order=self.layer_order), - total=n_layers, - disable=not self.display_progressbar, - ) - - # Get the number of bins for discrete entropy calculation. - if "n_bins" not in self.complexity_func_kwargs: - self.find_n_bins( - a_batch=a_batch, - n_bins_default=self.complexity_func_kwargs.get("n_bins_default", 100), - min_n_bins=self.complexity_func_kwargs.get("min_n_bins", 10), - max_n_bins=self.complexity_func_kwargs.get("max_n_bins", 200), - debug=self.complexity_func_kwargs.get("debug", False), - ) - - # Compute the explanation_scores given uniformly sampled explanation. - if self.nr_samples is None: - self.nr_samples = len(a_batch) - - # Initialise arrays. - self.delta_explanation_scores = np.zeros((self.nr_samples)) - self.delta_model_scores = np.zeros((self.nr_samples)) - self.fraction_explanation_scores = np.zeros((self.nr_samples)) - self.fraction_model_scores = np.zeros((self.nr_samples)) - self.delta_explanation_vs_models = np.zeros((self.nr_samples)) - self.correlation_scores = np.zeros((self.nr_samples)) - self.rate_of_change_scores = np.zeros((self.nr_samples)) - self.explanation_scores = {} - self.model_scores = {} - - for l_ix, (layer_name, random_layer_model) in enumerate(model_iterator): - - if l_ix == 0: - - # Generate an explanation with perturbed model. - a_batch_original = self.explain_func( - model=model.get_model(), - inputs=x_batch, - targets=y_batch, - **self.explain_func_kwargs, - ) - - self.explanation_scores["orig"] = [] - for a_ix, a_ori in enumerate(a_batch_original): - score = self.evaluate_instance( - model=model, - x=x_batch[0], - y=None, - s=None, - a=a_ori, - ) - self.explanation_scores["orig"].append(score) - - # Compute entropy of the output layer. - self.model_scores["orig"] = [] - for y_ix, y_pred in enumerate(model.predict(x_batch)): - score = entropy(a=y_pred, x=y_pred) - self.model_scores["orig"].append(score) - - # Skip layers if computing delta. - if ( - self.skip_layers - and self.compute_delta - and (l_ix + 1) < len(model_iterator) - ): - continue - - # Score explanation complexity. - explanation_scores = [] - - # Generate an explanation with perturbed model. - a_batch_perturbed = self.explain_func( - model=random_layer_model, - inputs=x_batch, - targets=y_batch, - **self.explain_func_kwargs, - ) - - # Get id for storing data. 
- if attributions_path is not None: - savepath = os.path.join(attributions_path, f"{l_ix}-{layer_name}") - os.makedirs(savepath, exist_ok=True) - last_id = 0 - for fname in os.listdir(savepath): - if "original_attribution_" in fname: - id = ( - int(fname.split("original_attribution_")[1].split(".")[0]) - > last_id - ) - if id > last_id: - last_id = id - - batch_iterator = enumerate(zip(a_batch, a_batch_perturbed)) - for instance_id, (a_ix, a_perturbed) in batch_iterator: - score = self.evaluate_instance( - model=random_layer_model, - x=x_batch[0], - y=None, - s=None, - a=a_perturbed, - ) - explanation_scores.append(score) - - # Save data. - if attributions_path is not None: - np.save( - os.path.join(savepath, f"input_{last_id+instance_id}.npy"), - x_batch[instance_id], - ) - np.save( - os.path.join( - savepath, f"original_attribution_{last_id+instance_id}.npy" - ), - a_ix, - ) - np.save( - os.path.join( - savepath, f"perturbed_attribution_{last_id+instance_id}.npy" - ), - a_perturbed, - ) - - # Score the model complexity. - model_scores = [] - - # Wrap the model. - random_layer_model_wrapped = utils.get_wrapped_model( - model=random_layer_model, - channel_first=channel_first, - softmax=softmax, - device=device, - model_predict_kwargs=model_predict_kwargs, - ) - - # Predict and save scores. - y_preds = random_layer_model_wrapped.predict(x_batch) - for y_ix, y_pred in enumerate(y_preds): - score = entropy(a=y_pred, x=y_pred) - model_scores.append(score) - - # Save explanation_scores scores in a result dictionary. - self.explanation_scores[layer_name] = explanation_scores - self.model_scores[layer_name] = model_scores - - # Call post-processing. - self.custom_postprocess( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - ) - - # If compute correlation score (model and explanations) - if self.compute_correlation: - self.correlation_scores = ( - self.recompute_model_explanation_correlation_per_sample() - ) - - # If compute the last complexity score. - if self.compute_last_complexity: - self.last_complexity_scores = self.recompute_last_correlation_per_sample() - - # If compute delta score per sample (model and explanations). - if self.compute_delta: - - # Compute deltas for explanation scores. - scores = list(self.explanation_scores.values()) - self.delta_explanation_scores = [ - b - a for a, b in zip(scores[0], scores[-1]) - ] - - # Compute deltas for model scores. - scores = list(self.model_scores.values()) - self.delta_model_scores = [b - a for a, b in zip(scores[0], scores[-1])] - - # Compute fraction for explanation scores. - scores = list(self.explanation_scores.values()) - self.fraction_explanation_scores = [ - b / a if a != 0 else np.nan for a, b in zip(scores[0], scores[-1]) - ] # eMPRT original! - - # Compute fraction for explanation scores. - scores = list(self.model_scores.values()) - self.fraction_model_scores = [ - b / a if a != 0 else np.nan for a, b in zip(scores[0], scores[-1]) - ] - - # If compute delta skill score per sample (model and explanations). - if self.compute_delta_explanation_vs_model: - self.delta_explanation_vs_models = [ - b / a if a != 0 else np.nan - for a, b in zip( - self.fraction_model_scores, self.fraction_explanation_scores - ) - ] - - # If compute delta skill score per sample (model and explanations). - if self.compute_rate_of_change: - scores = list(self.explanation_scores.values()) - self.rate_of_change_scores = [ - (b - a) / a for a, b in zip(scores[0], scores[-1]) - ] - - # If return one score per sample. 
- if self.return_average_sample_score: - self.evaluation_scores = self.recompute_average_complexity_per_sample() - - # If return delta score per sample. - if self.return_fraction: - self.evaluation_scores = self.fraction_explanation_scores - - # If return delta score per sample. - if self.return_delta_explanation_vs_model: - self.evaluation_scores = self.delta_explanation_vs_models - - # If return delta score per sample. - if self.return_correlation: - self.evaluation_scores = self.correlation_scores - - if self.return_last_complexity: - self.evaluation_scores = self.last_complexity_scores - - if self.return_rate_of_change: - self.evaluation_scores = self.rate_of_change_scores - - # If return one aggregate score for all samples. - if self.return_aggregate: - self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)] - - # Return all_evaluation_scores according to Quantus. - self.all_evaluation_scores.append(self.evaluation_scores) - - return self.evaluation_scores - - def evaluate_instance( - self, - model: ModelInterface, - x: Optional[np.ndarray], - y: Optional[np.ndarray], - a: Optional[np.ndarray], - s: Optional[np.ndarray], - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - i: integer - The evaluation instance. - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - if self.normalise: - a = self.normalise_func(a, **self.normalise_func_kwargs) - - if self.abs: - a = np.abs(a) - - # Compute distance measure. - return self.complexity_func(a=a, x=x, **self.complexity_func_kwargs) - - def custom_preprocess( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: Optional[np.ndarray], - a_batch: Optional[np.ndarray], - s_batch: np.ndarray, - custom_batch: Optional[np.ndarray], - ) -> None: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model e.g., torchvision.models that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - custom_batch: any - Gives flexibility ot the user to use for evaluation, can hold any variable. - - Returns - ------- - None - """ - # Additional explain_func assert, as the one in general_preprocess() - # won't be executed when a_batch != None. - - asserts.assert_explain_func(explain_func=self.explain_func) - - def recompute_model_explanation_correlation_per_sample( - self, - ) -> Union[List[List[Any]], Dict[int, List[Any]]]: - - assert isinstance(self.explanation_scores, dict), ( - "To compute the correlation between model and explanation per sample for " - "enhanced Model Parameter Randomisation Test, 'explanation_scores' " - "must be of type dict." 
- ) - layer_length = len( - self.explanation_scores[list(self.explanation_scores.keys())[0]] - ) - explanation_scores: Dict[int, list] = { - sample: [] for sample in range(layer_length) - } - model_scores: Dict[int, list] = {sample: [] for sample in range(layer_length)} - - for sample in explanation_scores.keys(): - for layer in self.explanation_scores: - explanation_scores[sample].append( - float(self.explanation_scores[layer][sample]) - ) - model_scores[sample].append(float(self.model_scores[layer][sample])) - - corr_coeffs = [] - for sample in explanation_scores.keys(): - corr_coeffs.append( - self.similarity_func(model_scores[sample], explanation_scores[sample]) - ) - - return corr_coeffs - - def recompute_average_complexity_per_sample( - self, - ) -> Union[List[List[Any]], Dict[int, List[Any]]]: - - assert isinstance(self.explanation_scores, dict), ( - "To compute the average correlation coefficient per sample for " - "enhanced Model Parameter Randomisation Test, 'explanation_scores' " - "must be of type dict." - ) - layer_length = len( - self.explanation_scores[list(self.explanation_scores.keys())[0]] - ) - results: Dict[int, list] = {sample: [] for sample in range(layer_length)} - - for sample in results: - for layer in self.explanation_scores: - if layer == "orig": - continue - results[sample].append(float(self.explanation_scores[layer][sample])) - results[sample] = np.mean(results[sample]) - - corr_coeffs = list(results.values()) - - return corr_coeffs - - def recompute_last_correlation_per_sample( - self, - ) -> Union[List[List[Any]], Dict[int, List[Any]]]: - - assert isinstance(self.explanation_scores, dict), ( - "To compute the last correlation coefficient per sample for " - "Model Parameter Randomisation Test, 'explanation_scores' " - "must be of type dict." - ) - corr_coeffs = list(self.explanation_scores.values())[-1] - - return corr_coeffs - - def find_n_bins( - self, - a_batch: np.array, - n_bins_default: int = 100, - min_n_bins: int = 10, - max_n_bins: int = 200, - debug: bool = True, - ) -> None: - - if self.normalise: - a_batch = self.normalise_func(a, **self.normalise_func_kwargs) - if self.abs: - a_batch = np.abs(a_batch) - - rule_name = self.complexity_func_kwargs.get("rule", None) - rule = RULES_N_BINS.get(rule_name) - - if debug: - print(f"\tMax and min value of a_batch=({a_batch.min()}, {a_batch.max()})") - - if not rule: - self.complexity_func_kwargs["n_bins"] = n_bins_default - if debug: - print(f"\tNo rule found, 'n_bins' set to 100.") - return None - - n_bins = rule(a_batch=a_batch) - n_bins = max(min(n_bins, max_n_bins), min_n_bins) - self.complexity_func_kwargs["n_bins"] = n_bins - - if debug: - print( - f"\tRule '{rule_name}' -> n_bins={n_bins} but with min={min_n_bins} and max={max_n_bins}, 'n_bins' set to {self.complexity_func_kwargs['n_bins']}." - ) diff --git a/quantus/metrics/randomisation/model_parameter_randomisation.py b/quantus/metrics/randomisation/mprt.py similarity index 91% rename from quantus/metrics/randomisation/model_parameter_randomisation.py rename to quantus/metrics/randomisation/mprt.py index ac814b269..a11468765 100644 --- a/quantus/metrics/randomisation/model_parameter_randomisation.py +++ b/quantus/metrics/randomisation/mprt.py @@ -7,6 +7,7 @@ # Quantus project URL: . import sys +import warnings from typing import ( Any, Callable, @@ -43,13 +44,13 @@ @final -class ModelParameterRandomisation(Metric): +class MPRT(Metric): """ - Implementation of the Model Parameter Randomisation Method by Adebayo et. al., 2018. 
+ Implementation of the Model Parameter Randomisation Test (MPRT) by Adebayo et al., 2018. - The Model Parameter Randomisation measures the distance between the original attribution and a newly computed - attribution throughout the process of cascadingly/independently randomizing the model parameters of one layer - at a time. + The MPRT measures the distance between the original attribution and a newly computed + attribution throughout the process of cascadingly/ independently randomizing the model + parameters of one layer at a time. Assumptions: - In the original paper multiple distance measures are taken: Spearman rank correlation (with and without abs), @@ -277,7 +278,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ # Run deprecation warnings. @@ -316,10 +317,6 @@ def __call__( # Set property to False, so we display only 1 pbar. self._display_progressbar = False - def generate_y_batches(): - for batch in gen_batches(len(a_full_dataset), batch_size): - yield a_full_dataset[batch.start : batch.stop] - with pbar as pbar: for l_ix, (layer_name, random_layer_model) in enumerate( model.get_random_layer_generator(order=self.layer_order, seed=self.seed) @@ -332,7 +329,7 @@ def generate_y_batches(): if l_ix == 0: - # Generate explanations on modified model in batches. + # Generate explanations on original model in batches. a_original_generator = self.generate_explanations( model.get_model(), x_full_dataset, y_full_dataset, batch_size ) @@ -340,7 +337,7 @@ def generate_y_batches(): # Compute the similarity of explanations of the original model. self.evaluation_scores["original"] = [] for a_batch, a_batch_original in zip( - generate_y_batches(), a_original_generator + self.generate_a_batches(a_full_dataset), a_original_generator ): for a_instance, a_instance_original in zip( a_batch, a_batch_original @@ -359,14 +356,14 @@ def generate_y_batches(): self.evaluation_scores[layer_name] = [] - # Generate explanations on modified model in batches. + # Generate explanations on perturbed model in batches. a_perturbed_generator = self.generate_explanations( random_layer_model, x_full_dataset, y_full_dataset, batch_size ) # Compute the similarity of explanations of the perturbed model. for a_batch, a_batch_perturbed in zip( - generate_y_batches(), a_perturbed_generator + self.generate_a_batches(a_full_dataset), a_perturbed_generator ): for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed): score = self.evaluate_instance( @@ -452,8 +449,6 @@ def evaluate_instance( Parameters ---------- - i: integer - The evaluation instance. model: ModelInterface A ModelInteface that is subject to explanation. x: np.ndarray @@ -522,14 +517,46 @@ def generate_explanations( y_batch: np.ndarray, batch_size: int, ) -> Generator[np.ndarray, None, None]: - """Iterate over dataset in batches and generate explanations for complete dataset""" + """ + Iterate over dataset in batches and generate explanations for complete dataset. + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. 
+        batch_size: int
+            The number of instances to explain per iteration.
+
+        Returns
+        -------
+        a_batch:
+            Batch of explanations ready to be evaluated.
+        """
         for i in gen_batches(len(x_batch), batch_size):
             x = x_batch[i.start : i.stop]
             y = y_batch[i.start : i.stop]
             a = self.explain_batch(model, x, y)
             yield a
 
+    def generate_a_batches(self, a_full_dataset):
+        """Iterate over the pre-computed explanation dataset in batches of self.batch_size."""
+        for batch in gen_batches(len(a_full_dataset), self.batch_size):
+            yield a_full_dataset[batch.start : batch.stop]
+
     def evaluate_batch(self, *args, **kwargs):
         raise RuntimeError(
-            "`evaluate_batch` must never be called for `Model Parameter Randomisation`."
+            "`evaluate_batch` must never be called for `Model Parameter Randomisation Test`."
+        )
+
+
+@final
+class ModelParameterRandomisation(MPRT):
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "ModelParameterRandomisation has been renamed to MPRT and will be removed in future releases. "
+            "Please use MPRT instead. This change is effective from Quantus version 0.5.0.",
+            DeprecationWarning,
         )
+        super().__init__(*args, **kwargs)
diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py
index b5e9c9e0d..cc42db9bc 100644
--- a/quantus/metrics/randomisation/random_logit.py
+++ b/quantus/metrics/randomisation/random_logit.py
@@ -220,7 +220,7 @@ def __call__(
 
         # Initialise the metric and evaluate explanations by calling the metric instance.
         >> metric = Metric(abs=True, normalise=False)
-        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency}
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
         """
         return super().__call__(
             model=model,
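The new SmoothMPRT metric added below denoises explanations before the randomisation comparison by averaging attributions over several noise-perturbed copies of each input, in the spirit of SmoothGrad. A minimal, library-agnostic sketch of that averaging step (the generic explain_fn callable and its signature are illustrative assumptions, not the Quantus API):

    import numpy as np

    def smooth_explanation(model, x, y, explain_fn, nr_samples=50, noise_magnitude=0.1):
        """Average explanations over noisy copies of x; the final sample is noise-free."""
        # Noise scale is taken relative to the input's value range.
        std = noise_magnitude * (x.max() - x.min())
        a_smooth = None
        for n in range(nr_samples):
            # Keep the last sample noise-free, so nr_samples=1 reduces to the plain explanation.
            eps = np.zeros_like(x) if n == nr_samples - 1 else np.random.randn(*x.shape) * std
            a = explain_fn(model, x + eps, y)
            a_smooth = a / nr_samples if a_smooth is None else a_smooth + a / nr_samples
        return a_smooth

With nr_samples=1 the loop makes a single noise-free call, i.e. the result equals the plain explanation; larger values trade runtime for lower-variance attributions.
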
diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py
new file mode 100644
index 000000000..b4342c705
--- /dev/null
+++ b/quantus/metrics/randomisation/smooth_mprt.py
@@ -0,0 +1,794 @@
+"""This module contains the implementation of the Smooth Model Parameter Randomisation Test metric."""
+
+# This file is part of Quantus.
+# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
+# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
+# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see .
+# Quantus project URL: .
+
+import sys
+import warnings
+from typing import (
+    Any,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Union,
+    Tuple,
+    Generator,
+    Iterable,
+)
+from importlib import util
+
+import numpy as np
+import quantus
+from tqdm.auto import tqdm
+from sklearn.utils import gen_batches
+from scipy import stats
+
+from quantus.functions.similarity_func import correlation_spearman
+from quantus.helpers import asserts, warn, utils
+from quantus.helpers.enums import (
+    DataType,
+    EvaluationCategory,
+    ModelType,
+    ScoreDirection,
+)
+from quantus.helpers.model.model_interface import ModelInterface
+from quantus.metrics.base import Metric
+
+if sys.version_info >= (3, 8):
+    from typing import final
+else:
+    from typing_extensions import final
+
+if util.find_spec("torch"):
+    import torch
+
+
+@final
+class SmoothMPRT(Metric):
+    """
+    Implementation of the Smooth MPRT by Hedström et al., 2023.
+
+    The Smooth MPRT measures the distance between the original attribution and a newly computed
+    attribution throughout the process of cascadingly/independently randomising the model parameters
+    of one layer at a time, where each attribution is denoised by averaging over noise-perturbed inputs.
+
+    References:
+        1) Hedström, Anna, et al. "Sanity Checks Revisited: An Exploration to Repair the Model Parameter
+        Randomisation Test." XAI in Action: Past, Present, and Future Applications. 2023.
+
+    Attributes:
+        - _name: The name of the metric.
+        - _data_applicability: The data types that the metric implementation currently supports.
+        - _models: The model types that this metric can work with.
+        - score_direction: How to interpret the scores, whether higher/ lower values are considered better.
+        - evaluation_category: What property/ explanation quality that this metric measures.
+    """
+
+    name = "Smooth Model Parameter Randomisation Test"
+    data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
+    model_applicability = {ModelType.TORCH, ModelType.TF}
+    score_direction = ScoreDirection.LOWER
+    evaluation_category = EvaluationCategory.RANDOMISATION
+
+    def __init__(
+        self,
+        similarity_func: Optional[Callable] = None,
+        layer_order: str = "bottom_up",
+        seed: int = 42,
+        nr_samples: int = 50,
+        noise_magnitude: float = 0.1,
+        return_average_correlation: bool = False,
+        return_last_correlation: bool = False,
+        skip_layers: bool = False,
+        abs: bool = True,
+        normalise: bool = True,
+        normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+        normalise_func_kwargs: Optional[Dict[str, Any]] = None,
+        return_aggregate: bool = False,
+        aggregate_func: Optional[Callable] = None,
+        default_plot_func: Optional[Callable] = None,
+        disable_warnings: bool = False,
+        display_progressbar: bool = False,
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        similarity_func: callable
+            Similarity function applied to compare input and perturbed input, default=correlation_spearman.
+        layer_order: string
+            Indicates whether the model is randomised cascadingly or independently.
+            Set order=top_down for cascading randomisation, set order=independent for independent randomisation,
+            default="bottom_up".
+        seed: integer
+            Seed used for the random generator, default=42.
+        nr_samples: integer
+            The number of samples used to compute the average (denoised) explanations, default=50.
+            The default value is set based on the ImageNet experiment in the reference paper.
+            Please update the value according to your use case.
+        noise_magnitude: float
+            The magnitude of the noise added to the input, default=0.1.
+            The default value is set based on the ImageNet experiment in the reference paper.
+            Please update the value according to your use case.
+        return_average_correlation: boolean
+            Indicates whether to return one float per sample, computing the average
+            correlation coefficient across the layers for a given sample.
+        return_last_correlation: boolean
+            Indicates whether to return one float per sample, computing the explanation
+            correlation coefficient for the full model randomisation (not layer-wise) of a sample.
+        skip_layers: boolean
+            Indicates if explanation similarity should be computed only once, between the
+            original and fully randomised model, instead of on a layer-by-layer basis.
+        abs: boolean
+            Indicates whether absolute operation is applied on the attribution, default=True.
+        normalise: boolean
+            Indicates whether normalise operation is applied on the attribution, default=True.
+        normalise_func: callable
+            Attribution normalisation function applied in case normalise=True.
+            If normalise_func=None, the default value is used, default=normalise_by_max.
+        normalise_func_kwargs: dict
+            Keyword arguments to be passed to normalise_func on call, default={}.
+        return_aggregate: boolean
+            Indicates if an aggregated score should be computed over all instances.
+        aggregate_func: callable
+            Callable that aggregates the scores given an evaluation call.
+        default_plot_func: callable
+            Callable that plots the metrics result.
+        disable_warnings: boolean
+            Indicates whether the warnings are printed, default=False.
+        display_progressbar: boolean
+            Indicates whether a tqdm-progress-bar is printed, default=False.
+        kwargs: optional
+            Keyword arguments.
+        """
+
+        super().__init__(
+            abs=abs,
+            normalise=normalise,
+            normalise_func=normalise_func,
+            normalise_func_kwargs=normalise_func_kwargs,
+            return_aggregate=return_aggregate,
+            aggregate_func=aggregate_func,
+            default_plot_func=default_plot_func,
+            display_progressbar=display_progressbar,
+            disable_warnings=disable_warnings,
+            **kwargs,
+        )
+
+        # Save metric-specific attributes.
+        if similarity_func is None:
+            similarity_func = correlation_spearman
+        self.similarity_func = similarity_func
+        self.layer_order = layer_order
+        self.seed = seed
+        self.nr_samples = nr_samples
+        self.noise_magnitude = noise_magnitude
+        self.return_average_correlation = return_average_correlation
+        self.return_last_correlation = return_last_correlation
+        self.skip_layers = skip_layers
+
+        # Results are returned/saved as a dictionary, not as a list as in the super-class.
+        self.evaluation_scores = {}
+
+        # Asserts and warnings.
+        if self.return_average_correlation and self.return_last_correlation:
+            raise ValueError(
+                "Both 'return_average_correlation' and 'return_last_correlation' cannot be set to 'True'. "
+                "Set both to 'False' or one of the attributes to 'True'."
+            )
+        if self.return_average_correlation and self.skip_layers:
+            raise ValueError(
+                "Both 'return_average_correlation' and 'skip_layers' cannot be set to 'True'. "
+                "You need to calculate the explanation correlation at all layers in order "
+                "to compute the average correlation coefficient on all layers."
+            )
+        asserts.assert_layer_order(layer_order=self.layer_order)
+        if not self.disable_warnings:
+            warn.warn_parameterisation(
+                metric_name=self.__class__.__name__,
+                sensitive_params=(
+                    "the similarity metric 'similarity_func', the order of "
+                    "the layer randomisation 'layer_order', the number of samples 'nr_samples', and"
+                    " the magnitude of noise 'noise_magnitude' "
+                ),
+                citation=(
+                    'Hedström, Anna, et al. "Sanity Checks Revisited: An Exploration to Repair'
+                    ' the Model Parameter Randomisation Test." XAI in Action: Past, Present, '
+                    "and Future Applications. 2023."
+                ),
+            )
+
+    def __call__(
+        self,
+        model,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        a_batch: Optional[np.ndarray] = None,
+        s_batch: Optional[np.ndarray] = None,
+        channel_first: Optional[bool] = None,
+        explain_func: Optional[Callable] = None,
+        explain_func_kwargs: Optional[Dict] = None,
+        model_predict_kwargs: Optional[Dict] = None,
+        softmax: Optional[bool] = False,
+        device: Optional[str] = None,
+        batch_size: int = 64,
+        **kwargs,
+    ) -> Union[List[float], float, Dict[str, List[float]], Collection[Any]]:
+        """
+        This implementation represents the main logic of the metric and makes the class object callable.
+        It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch),
+        output labels (y_batch) and a torch or tensorflow model (model).
+
+        Calls general_preprocess() with all relevant arguments, calls
+        evaluate_instance() on each instance, and saves results to evaluation_scores.
+        Calls custom_postprocess() afterwards. Finally returns evaluation_scores.
+
+        The content of evaluation_scores will be appended to all_evaluation_scores (list) at the end of
+        the evaluation call.
+
+        Parameters
+        ----------
+        model: torch.nn.Module, tf.keras.Model
+            A torch or tensorflow model that is subject to explanation.
+        x_batch: np.ndarray
+            A np.ndarray which contains the input data that are explained.
+        y_batch: np.ndarray
+            A np.ndarray which contains the output labels that are explained.
+        a_batch: np.ndarray, optional
+            A np.ndarray which contains pre-computed attributions i.e., explanations.
+        s_batch: np.ndarray, optional
+            A np.ndarray which contains segmentation masks that matches the input.
+        channel_first: boolean, optional
+            Indicates whether the image dimensions are channel first, or channel last.
+            Inferred from the input shape if None.
+        explain_func: callable
+            Callable generating attributions.
+        explain_func_kwargs: dict, optional
+            Keyword arguments to be passed to explain_func on call.
+        model_predict_kwargs: dict, optional
+            Keyword arguments to be passed to the model's predict method.
+        softmax: boolean
+            Indicates whether to use softmax probabilities or logits in model prediction.
+            This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used.
+        device: string
+            Indicates the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu".
+        batch_size: int
+            The batch size to be used when generating explanations, default=64.
+        kwargs: optional
+            Keyword arguments.
+
+        Returns
+        -------
+        evaluation_scores: list
+            A list of Any with the evaluation scores of the concerned batch.
+
+        Examples:
+        --------
+        # Minimal imports.
+        >> import quantus
+        >> from quantus import LeNet
+        >> import torch
+
+        # Enable GPU.
+        >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+        # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
+        >> model = LeNet()
+        >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
+
+        # Load MNIST datasets and make loaders.
+        >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
+        >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
+
+        # Load a batch of inputs and outputs to use for XAI evaluation.
+        >> x_batch, y_batch = next(iter(test_loader))
+        >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
+
+        # Generate Saliency attributions for the test set batch.
+        >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
+        >> a_batch_saliency = a_batch_saliency.cpu().numpy()
+
+        # Initialise the metric and evaluate explanations by calling the metric instance.
+        >> metric = Metric(abs=True, normalise=False)
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
+        """
+
+        # Run deprecation warnings.
+        warn.deprecation_warnings(kwargs)
+        warn.check_kwargs(kwargs)
+        self.batch_size = batch_size
+        self.device = device
+
+        if not isinstance(channel_first, bool):  # None is not a boolean instance.
+            self.channel_first = utils.infer_channel_first(x_batch)
+
+        data = self.general_preprocess(
+            model=model,
+            x_batch=x_batch,
+            y_batch=y_batch,
+            a_batch=a_batch,
+            s_batch=s_batch,
+            custom_batch=None,
+            channel_first=channel_first,
+            explain_func=explain_func,
+            explain_func_kwargs=explain_func_kwargs,
+            model_predict_kwargs=model_predict_kwargs,
+            softmax=softmax,
+            device=device,
+        )
+        model: ModelInterface = data["model"]  # type: ignore
+        # Here _batch refers to the full dataset.
+        x_full_dataset = data["x_batch"]
+        y_full_dataset = data["y_batch"]
+        a_full_dataset = data["a_batch"]
+
+        # Results are returned/saved as a dictionary, not as a list as in the super-class.
+        self.evaluation_scores = {}
+
+        # Get number of iterations from number of layers.
+        n_layers = model.random_layer_generator_length
+        pbar = tqdm(
+            total=n_layers * len(x_full_dataset), disable=not self.display_progressbar
+        )
+        if self.display_progressbar:
+            # Set property to False, so we display only 1 pbar.
+            self._display_progressbar = False
+
+        with pbar as pbar:
+            for l_ix, (layer_name, random_layer_model) in enumerate(
+                model.get_random_layer_generator(order=self.layer_order, seed=self.seed)
+            ):
+                pbar.desc = layer_name
+
+                # Skip intermediate layers if only the fully randomised model is compared.
+                if self.skip_layers and (l_ix + 1) < n_layers:
+                    continue
+
+                if l_ix == 0:
+
+                    # Generate explanations on original model in batches.
+                    a_original_generator = self.generate_explanations(
+                        model.get_model(),
+                        x_full_dataset,
+                        y_full_dataset,
+                        **kwargs,
+                    )
+
+                    # Compute the similarity of explanations of the original model.
+                    self.evaluation_scores["original"] = []
+                    for a_batch, a_batch_original in zip(
+                        self.generate_a_batches(a_full_dataset), a_original_generator
+                    ):
+                        for a_instance, a_instance_original in zip(
+                            a_batch, a_batch_original
+                        ):
+                            score = self.evaluate_instance(
+                                model=model,
+                                x=None,
+                                y=None,
+                                s=None,
+                                a=a_instance,
+                                a_perturbed=a_instance_original,
+                            )
+                            # Save similarity scores in a result dictionary.
+                            self.evaluation_scores["original"].append(score)
+                            pbar.update(1)
+
+                self.evaluation_scores[layer_name] = []
+
+                # Generate explanations on perturbed model in batches.
+                a_perturbed_generator = self.generate_explanations(
+                    random_layer_model,
+                    x_full_dataset,
+                    y_full_dataset,
+                    **kwargs,
+                )
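+                # Note that generate_explanations is lazy: nothing is computed
+                # until the zip below pulls batches, so the full set of
+                # perturbed explanations is never held in memory at once.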
+                # Compute the similarity of explanations of the perturbed model.
+                for a_batch, a_batch_perturbed in zip(
+                    self.generate_a_batches(a_full_dataset), a_perturbed_generator
+                ):
+                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
+                        score = self.evaluate_instance(
+                            model=random_layer_model,
+                            x=None,
+                            y=None,
+                            s=None,
+                            a=a_instance,
+                            a_perturbed=a_instance_perturbed,
+                        )
+                        self.evaluation_scores[layer_name].append(score)
+                        pbar.update(1)
+
+        if self.return_average_correlation:
+            self.evaluation_scores = self.recompute_average_correlation_per_sample()
+
+        elif self.return_last_correlation:
+            self.evaluation_scores = self.recompute_last_correlation_per_sample()
+
+        if self.return_aggregate:
+            assert self.return_average_correlation or self.return_last_correlation, (
+                "Set 'return_average_correlation' or 'return_last_correlation'"
+                " to True in order to compute the aggregate evaluation results."
+            )
+            self.evaluation_scores = [self.aggregate_func(self.evaluation_scores)]
+
+        # Return all_evaluation_scores according to Quantus.
+        self.all_evaluation_scores.append(self.evaluation_scores)
+
+        return self.evaluation_scores
+
+    def recompute_average_correlation_per_sample(
+        self,
+    ) -> List[float]:
+
+        assert isinstance(self.evaluation_scores, dict), (
+            "To compute the average correlation coefficient per sample for "
+            "Smooth Model Parameter Randomisation Test, 'evaluation_scores' "
+            "must be of type dict."
+        )
+        n_samples = len(
+            self.evaluation_scores[list(self.evaluation_scores.keys())[0]]
+        )
+        results: Dict[int, list] = {sample: [] for sample in range(n_samples)}
+
+        for sample in results:
+            for layer in self.evaluation_scores:
+                if layer == "original":
+                    continue
+                results[sample].append(float(self.evaluation_scores[layer][sample]))
+            results[sample] = np.mean(results[sample])
+
+        corr_coeffs = list(results.values())
+
+        return corr_coeffs
+
+    def recompute_last_correlation_per_sample(
+        self,
+    ) -> List[float]:
+
+        assert isinstance(self.evaluation_scores, dict), (
+            "To compute the last correlation coefficient per sample for "
+            "Smooth Model Parameter Randomisation Test, 'evaluation_scores' "
+            "must be of type dict."
+        )
+        # Return the correlation coefficient of the fully randomised model
+        # (excluding the non-randomised correlation).
+        corr_coeffs = list(self.evaluation_scores.values())[-1]
+        corr_coeffs = [float(c) for c in corr_coeffs]
+        return corr_coeffs
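+    # A small worked example of the two recompute helpers above (illustrative
+    # values, assuming two randomised layers and two samples):
+    #
+    #   evaluation_scores = {"original": [1.0, 1.0],
+    #                        "conv": [0.8, 0.6],
+    #                        "fc": [0.2, 0.1]}
+    #
+    # The average variant skips "original" and returns [0.5, 0.35]; the last
+    # variant returns the scores of the fully randomised model, i.e. [0.2, 0.1].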
" + "Returning a similarity measure of 1.", + UserWarning, + ) + return 1.0 + + # Check if either array is constant + if np.all(a_flat == a_flat[0]) or np.all( + a_perturbed_flat == a_perturbed_flat[0] + ): + warnings.warn( + "One of the input arrays is constant; " + "the correlation coefficient is not defined.", + UserWarning, + ) + return 1.0 # or some other default value + + # Compute similarity measure + try: + return self.similarity_func(a_perturbed_flat, a_flat) + except stats._warnings_errors.ConstantInputWarning: + warnings.warn( + "Encountered constant input in similarity measure calculation.", + UserWarning, + ) + return 1.0 + + # Compute similarity measure. + return self.similarity_func(a_perturbed_flat, a_flat) + + def custom_preprocess( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: Optional[np.ndarray], + **kwargs, + ) -> Optional[Dict[str, np.ndarray]]: + """ + Implementation of custom_preprocess_batch. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model e.g., torchvision.models that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + kwargs: + Unused. + Returns + ------- + None + """ + # Additional explain_func assert, as the one in general_preprocess() + # won't be executed when a_batch != None. + asserts.assert_explain_func(explain_func=self.explain_func) + if a_batch is not None: # Just to silence mypy warnings + return None + + a_batch_chunks = [] + for a_chunk in self.generate_explanations(model, x_batch, y_batch): + a_batch_chunks.extend(a_chunk) + return dict(a_batch=np.asarray(a_batch_chunks)) + + def generate_explanations( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + **kwargs, + ) -> Generator[np.ndarray, None, None]: + """ + Iterate over dataset in batches and generate explanations for complete dataset. + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + kwargs: optional, dict + List of hyperparameters. + + Returns + ------- + a_batch: + Batch of explanations ready to be evaluated. + """ + for i in gen_batches(len(x_batch), self.batch_size): + x = x_batch[i.start : i.stop] + y = y_batch[i.start : i.stop] + a = self.explain_smooth_batch( + model=model, + x_batch=x, + y_batch=y, + **kwargs, + ) + yield a + + def generate_a_batches(self, a_full_dataset): + for batch in gen_batches(len(a_full_dataset), self.batch_size): + yield a_full_dataset[batch.start : batch.stop] + + def evaluate_batch(self, *args, **kwargs): + raise RuntimeError( + "`evaluate_batch` must never be called for `Model Parameter Randomisation Test`." + ) + + def explain_smooth_batch( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + **kwargs, + ) -> np.ndarray: + """ + Compute explanations, normalize and take absolute (if was configured so during metric initialization.) + This method should primarily be used if you need to generate additional explanation + in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. 
+    def explain_smooth_batch(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Compute explanations, normalise and take absolute values (if so configured during metric initialisation).
+        This method should primarily be used if you need to generate additional explanations
+        in the metric's body. It encapsulates the typical Quantus pre- and post-processing approach.
+        It will do a few things:
+            - call model.shape_input (if a ModelInterface instance was provided)
+            - unwrap the model (if a ModelInterface instance was provided)
+            - call explain_func
+            - expand the attribution channel
+            - (optionally) normalise a_batch
+            - (optionally) take np.abs of a_batch
+
+        Parameters
+        -------
+        model:
+            A model that is subject to explanation.
+        x_batch:
+            A np.ndarray which contains the input data that are explained.
+        y_batch:
+            A np.ndarray which contains the output labels that are explained.
+        kwargs: optional, dict
+            Keyword arguments passed to the explanation function.
+
+        Returns
+        -------
+        a_batch:
+            Batch of explanations ready to be evaluated.
+        """
+        if isinstance(model, ModelInterface):
+            # Sometimes the model is our wrapper, but sometimes a raw Keras/Torch model.
+            x_batch = model.shape_input(
+                x=x_batch,
+                shape=x_batch.shape,
+                channel_first=True,
+                batched=True,
+            )
+            model = model.get_model()
+
+        # Set noise, scaled per sample to its value range.
+        dims = tuple(range(1, x_batch.ndim))
+        std = self.noise_magnitude * (
+            x_batch.max(axis=dims, keepdims=True)
+            - x_batch.min(axis=dims, keepdims=True)
+        )
+        a_batch_smooth = self.explain_smooth_batch_numpy(
+            model=model, x_batch=x_batch, y_batch=y_batch, std=std, **kwargs
+        )
+
+        a_batch_smooth = utils.expand_attribution_channel(a_batch_smooth, x_batch)
+        asserts.assert_attributions(x_batch=x_batch, a_batch=a_batch_smooth)
+
+        # Normalise and take absolute values of the attributions, if configured during metric instantiation.
+        if self.normalise:
+            a_batch_smooth = self.normalise_func(a_batch_smooth, **self.normalise_func_kwargs)
+
+        if self.abs:
+            a_batch_smooth = np.abs(a_batch_smooth)
+
+        return a_batch_smooth
+
+    def explain_smooth_batch_numpy(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        std: np.ndarray,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Average explanations over self.nr_samples noise-perturbed copies of x_batch,
+        where the final sample is noise-free. This method should primarily be used if
+        you need to generate additional explanations in the metric's body.
+
+        Parameters
+        -------
+        model:
+            A model that is subject to explanation.
+        x_batch:
+            A np.ndarray which contains the input data that are explained.
+        y_batch:
+            A np.ndarray which contains the output labels that are explained.
+        std: np.ndarray
+            Standard deviation of the Gaussian noise, per sample in the batch.
+        kwargs: optional, dict
+            Keyword arguments passed to the explanation function.
+
+        Returns
+        -------
+        a_batch:
+            Batch of explanations ready to be evaluated.
+        """
+        a_batch_smooth = None
+        for n in range(self.nr_samples):
+            # The last epsilon is defined as zero to include the clean input, so
+            # that SmoothGrad with nr_samples == 1 reduces to the plain gradient.
+            if n == self.nr_samples - 1:
+                epsilon = np.zeros_like(x_batch)
+            else:
+                epsilon = np.random.randn(*x_batch.shape) * std
+            a_batch = quantus.explain(model, x_batch + epsilon, y_batch, **kwargs)
+            if a_batch_smooth is None:
+                a_batch_smooth = a_batch / self.nr_samples
+            else:
+                a_batch_smooth += a_batch / self.nr_samples
+
+        return a_batch_smooth
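+    # A rough numeric sketch of the smoothing above (illustrative values): for
+    # an input whose per-sample value range is [0, 1] and noise_magnitude=0.1,
+    # the noise std is 0.1 * (1 - 0) = 0.1. Each of the nr_samples explanations
+    # contributes a_batch / nr_samples, so the result is their plain mean, with
+    # the final, zero-noise sample anchoring it to the clean explanation.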
+    def explain_smooth_batch_torch(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        std: np.ndarray,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Torch variant of explain_smooth_batch_numpy: averages explanations over
+        self.nr_samples noise-perturbed copies of x_batch, where the final sample
+        is noise-free. This method should primarily be used if you need to generate
+        additional explanations in the metric's body.
+
+        Parameters
+        -------
+        model:
+            A model that is subject to explanation.
+        x_batch:
+            A np.ndarray which contains the input data that are explained.
+        y_batch:
+            A np.ndarray which contains the output labels that are explained.
+        std: np.ndarray
+            Standard deviation of the Gaussian noise, per sample in the batch.
+        kwargs: optional, dict
+            Keyword arguments passed to the explanation function.
+
+        Returns
+        -------
+        a_batch:
+            Batch of explanations ready to be evaluated.
+        """
+        if not isinstance(x_batch, torch.Tensor):
+            x_batch = torch.Tensor(x_batch).to(self.device)
+
+        if not isinstance(y_batch, torch.Tensor):
+            y_batch = torch.as_tensor(y_batch).to(self.device)
+
+        if not isinstance(std, torch.Tensor):
+            std = torch.as_tensor(std, dtype=x_batch.dtype).to(self.device)
+
+        a_batch_smooth = None
+        for n in range(self.nr_samples):
+            # The last epsilon is defined as zero to include the clean input, so
+            # that SmoothGrad with nr_samples == 1 reduces to the plain gradient.
+            if n == self.nr_samples - 1:
+                epsilon = torch.zeros_like(x_batch)
+            else:
+                epsilon = torch.randn_like(x_batch) * std
+
+            a_batch = quantus.explain(model, x_batch + epsilon, y_batch, **kwargs)
+
+            if a_batch_smooth is None:
+                a_batch_smooth = a_batch / self.nr_samples
+            else:
+                a_batch_smooth += a_batch / self.nr_samples
+
+        return a_batch_smooth
diff --git a/quantus/metrics/robustness/avg_sensitivity.py b/quantus/metrics/robustness/avg_sensitivity.py
index 13b71cc7d..b44662e88 100644
--- a/quantus/metrics/robustness/avg_sensitivity.py
+++ b/quantus/metrics/robustness/avg_sensitivity.py
@@ -279,7 +279,7 @@ def __call__(
 
         # Initialise the metric and evaluate explanations by calling the metric instance.
         >> metric = Metric(abs=True, normalise=False)
-        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency}
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
         """
         return super().__call__(
             model=model,
diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py
index f67c8638c..ea782d2db 100644
--- a/quantus/metrics/robustness/consistency.py
+++ b/quantus/metrics/robustness/consistency.py
@@ -218,7 +218,7 @@ def __call__(
 
         # Initialise the metric and evaluate explanations by calling the metric instance.
         >> metric = Metric(abs=True, normalise=False)
-        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency}
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
         """
         return super().__call__(
             model=model,
diff --git a/quantus/metrics/robustness/continuity.py b/quantus/metrics/robustness/continuity.py
index cfb7fba34..9cf2b8911 100644
--- a/quantus/metrics/robustness/continuity.py
+++ b/quantus/metrics/robustness/continuity.py
@@ -262,7 +262,7 @@ def __call__(
 
         # Initialise the metric and evaluate explanations by calling the metric instance.
>> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/robustness/local_lipschitz_estimate.py b/quantus/metrics/robustness/local_lipschitz_estimate.py index 737b93bfa..cf04e6d2d 100644 --- a/quantus/metrics/robustness/local_lipschitz_estimate.py +++ b/quantus/metrics/robustness/local_lipschitz_estimate.py @@ -282,7 +282,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/quantus/metrics/robustness/max_sensitivity.py b/quantus/metrics/robustness/max_sensitivity.py index 8ab236386..34acfcfd4 100644 --- a/quantus/metrics/robustness/max_sensitivity.py +++ b/quantus/metrics/robustness/max_sensitivity.py @@ -275,7 +275,7 @@ def __call__( # Initialise the metric and evaluate explanations by calling the metric instance. >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency} + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) """ return super().__call__( model=model, diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py index bef8ea483..524c6abfe 100644 --- a/tests/metrics/test_randomisation_metrics.py +++ b/tests/metrics/test_randomisation_metrics.py @@ -9,9 +9,10 @@ from quantus.functions.similarity_func import correlation_spearman, correlation_pearson from quantus.helpers.model.model_interface import ModelInterface from quantus.metrics.randomisation import ( - ModelParameterRandomisation, + MPRT, + EfficientMPRT, + SmoothMPRT, RandomLogit, - EfficientModelParameterRandomisation, ) @@ -365,7 +366,7 @@ def test_model_parameter_randomisation( if "exception" in expected: with pytest.raises(expected["exception"]): - scores_layers = ModelParameterRandomisation(**init_params)( + scores_layers = MPRT(**init_params)( model=model, x_batch=x_batch, y_batch=y_batch, @@ -374,7 +375,7 @@ def test_model_parameter_randomisation( ) return - scores = ModelParameterRandomisation(**init_params)( + scores = MPRT(**init_params)( model=model, x_batch=x_batch, y_batch=y_batch, @@ -408,13 +409,15 @@ def test_model_parameter_randomisation( lazy_fixture("almost_uniform_1d_no_abatch"), { "init": { - "num_classes": 10, + "layer_order": "top_down", + "similarity_func": correlation_spearman, "normalise": True, "disable_warnings": False, "display_progressbar": False, + "nr_samples": 5, + "noise_magnitude": 0.1, }, "call": { - "softmax": True, "explain_func": explain, "explain_func_kwargs": { "method": "Saliency", @@ -428,13 +431,15 @@ def test_model_parameter_randomisation( lazy_fixture("load_mnist_images"), { "init": { - "num_classes": 10, + "layer_order": "top_down", + "similarity_func": correlation_spearman, "normalise": True, "disable_warnings": False, "display_progressbar": False, + "nr_samples": 5, + "noise_magnitude": 0.1, }, "call": { - "softmax": True, "explain_func": explain, "explain_func_kwargs": { "method": "Saliency", @@ -447,15 +452,16 @@ def 
test_model_parameter_randomisation( lazy_fixture("load_1d_3ch_conv_model"), lazy_fixture("almost_uniform_1d_no_abatch"), { - "a_batch_generate": False, "init": { - "num_classes": 10, - "normalise": False, + "layer_order": "bottom_up", + "similarity_func": correlation_pearson, + "normalise": True, "disable_warnings": True, "display_progressbar": False, + "nr_samples": 5, + "noise_magnitude": 0.1, }, "call": { - "softmax": True, "explain_func": explain, "explain_func_kwargs": { "method": "Saliency", @@ -468,15 +474,16 @@ def test_model_parameter_randomisation( lazy_fixture("load_mnist_model"), lazy_fixture("load_mnist_images"), { - "a_batch_generate": False, "init": { - "num_classes": 10, - "normalise": False, + "layer_order": "bottom_up", + "similarity_func": correlation_pearson, + "normalise": True, "disable_warnings": True, "display_progressbar": False, + "nr_samples": 5, + "noise_magnitude": 0.1, }, "call": { - "softmax": True, "explain_func": explain, "explain_func_kwargs": { "method": "Saliency", @@ -485,18 +492,88 @@ def test_model_parameter_randomisation( }, {"min": -1.0, "max": 1.0}, ), + ( + lazy_fixture("load_mnist_model_tf"), + lazy_fixture("load_mnist_images_tf"), + { + "init": { + "layer_order": "top_down", + "similarity_func": correlation_spearman, + "normalise": True, + "disable_warnings": True, + "display_progressbar": False, + "nr_samples": 5, + "noise_magnitude": 0.1, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "VanillaGradients", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_1d_3ch_conv_model_tf"), + lazy_fixture("almost_uniform_1d_no_abatch_channel_last"), + { + "a_batch_generate": False, + "init": { + "layer_order": "bottom_up", + "similarity_func": correlation_pearson, + "normalise": True, + "disable_warnings": True, + "display_progressbar": False, + "nr_samples": 5, + "noise_magnitude": 0.1, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "VanillaGradients", + }, + }, + }, + {"exception": ValueError}, + ), + ( + lazy_fixture("load_mnist_model_tf"), + lazy_fixture("load_mnist_images_tf"), + { + "a_batch_generate": False, + "init": { + "layer_order": "bottom_up", + "similarity_func": correlation_pearson, + "normalise": True, + "disable_warnings": True, + "display_progressbar": False, + "nr_samples": 5, + "noise_magnitude": 0.1, + }, + "call": { + "explain_func": explain, + "explain_func_kwargs": { + "method": "Gradient", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), ( lazy_fixture("load_1d_3ch_conv_model"), lazy_fixture("almost_uniform_1d_no_abatch"), { "init": { - "num_classes": 10, + "layer_order": "top_down", + "similarity_func": correlation_spearman, "normalise": True, "disable_warnings": True, "display_progressbar": True, + "nr_samples": 5, + "noise_magnitude": 0.1, }, "call": { - "softmax": True, "explain_func": explain, "explain_func_kwargs": { "method": "Saliency", @@ -510,13 +587,15 @@ def test_model_parameter_randomisation( lazy_fixture("load_mnist_images"), { "init": { - "num_classes": 10, + "layer_order": "independent", + "similarity_func": correlation_spearman, "normalise": True, "disable_warnings": True, "display_progressbar": True, + "nr_samples": 5, + "noise_magnitude": 0.1, }, "call": { - "softmax": True, "explain_func": explain, "explain_func_kwargs": { "method": "Saliency", @@ -530,13 +609,15 @@ def test_model_parameter_randomisation( lazy_fixture("titanic_dataset"), { "init": { - "num_classes": 2, + "layer_order": "independent", + 
"similarity_func": correlation_spearman, "normalise": True, "abs": True, "disable_warnings": True, + "nr_samples": 5, + "noise_magnitude": 0.1, }, "call": { - "softmax": True, "explain_func": explain, "explain_func_kwargs": { "method": "IntegratedGradients", @@ -544,26 +625,128 @@ def test_model_parameter_randomisation( }, }, }, - {"min": -1.0, "max": 1.1}, + {"min": -1.0, "max": 1.0}, ), ( lazy_fixture("titanic_model_tf"), lazy_fixture("titanic_dataset"), { "init": { - "num_classes": 2, + "layer_order": "independent", + "similarity_func": correlation_spearman, "normalise": True, "abs": True, "disable_warnings": True, - "similarity_func": correlation_pearson, + "nr_samples": 5, + "noise_magnitude": 0.1, }, - "call": {"softmax": True, "explain_func": explain_func_stub}, + "call": {"explain_func": explain_func_stub}, }, - {"min": -1.0, "max": 1.1}, + {"exception": ValueError}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "top_down", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": True, + "return_last_correlation": False, + "skip_layers": False, + "nr_samples": 5, + "noise_magnitude": 0.1, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": False, + "nr_samples": 5, + "noise_magnitude": 0.1, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "bottom_up", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + "nr_samples": 5, + "noise_magnitude": 0.1, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": False, + "return_last_correlation": True, + "skip_layers": True, + "nr_samples": 5, + "noise_magnitude": 1.0, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "layer_order": "independent", + "similarity_func": correlation_spearman, + "normalise": True, + "abs": True, + "disable_warnings": True, + "return_average_correlation": True, + "return_last_correlation": True, + "skip_layers": True, + "nr_samples": 5, + "noise_magnitude": 0.2, + }, + "call": {"explain_func": explain_func_stub}, + }, + {"exception": ValueError}, ), ], ) -def test_random_logit( +def test_smooth_model_parameter_randomisation( model: ModelInterface, data: np.ndarray, params: dict, @@ -590,19 +773,44 @@ def test_random_logit( a_batch = data["a_batch"] else: a_batch = None - scores = RandomLogit(**init_params)( + + if "exception" in expected: + with pytest.raises(expected["exception"]): + 
scores_layers = SmoothMPRT(**init_params)( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + **call_params, + ) + return + + scores = SmoothMPRT(**init_params)( model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch, **call_params, ) - assert all( - expected["min"] <= s <= expected["max"] for s in scores - ), f"Test failed with scores {scores}." + + if isinstance(scores, dict): + for layer, scores_layer in scores.items(): + out_of_range_scores = [ + s for s in scores_layer if not (expected["min"] <= s <= expected["max"]) + ] + assert ( + not out_of_range_scores + ), f"Test failed for layer {layer}. Out of range scores: {out_of_range_scores}" + elif isinstance(scores, list): + out_of_range_scores = [ + s for s in scores if not (expected["min"] <= s <= expected["max"]) + ] + assert ( + not out_of_range_scores + ), f"Test failed. Out of range scores: {out_of_range_scores}" -@pytest.mark.emprt +@pytest.mark.randomisation @pytest.mark.parametrize( "model,data,params,expected", [ @@ -921,7 +1129,7 @@ def test_efficient_model_parameter_randomisation( if "exception" in expected: with pytest.raises(expected["exception"]): - scores_layers = EfficientModelParameterRandomisation(**init_params)( + scores_layers = EfficientMPRT(**init_params)( model=model, x_batch=x_batch, y_batch=y_batch, @@ -930,7 +1138,7 @@ def test_efficient_model_parameter_randomisation( ) return - scores = EfficientModelParameterRandomisation(**init_params)( + scores = EfficientMPRT(**init_params)( model=model, x_batch=x_batch, y_batch=y_batch, @@ -944,3 +1152,218 @@ def test_efficient_model_parameter_randomisation( assert ( not out_of_range_scores ), f"Test failed. Out of range scores: {out_of_range_scores}" + + +@pytest.mark.randomisation +@pytest.mark.parametrize( + "model,data,params,expected", + [ + ( + lazy_fixture("load_1d_3ch_conv_model"), + lazy_fixture("almost_uniform_1d_no_abatch"), + { + "init": { + "num_classes": 10, + "normalise": True, + "disable_warnings": False, + "display_progressbar": False, + }, + "call": { + "softmax": True, + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "init": { + "num_classes": 10, + "normalise": True, + "disable_warnings": False, + "display_progressbar": False, + }, + "call": { + "softmax": True, + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_1d_3ch_conv_model"), + lazy_fixture("almost_uniform_1d_no_abatch"), + { + "a_batch_generate": False, + "init": { + "num_classes": 10, + "normalise": False, + "disable_warnings": True, + "display_progressbar": False, + }, + "call": { + "softmax": True, + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_mnist_model"), + lazy_fixture("load_mnist_images"), + { + "a_batch_generate": False, + "init": { + "num_classes": 10, + "normalise": False, + "disable_warnings": True, + "display_progressbar": False, + }, + "call": { + "softmax": True, + "explain_func": explain, + "explain_func_kwargs": { + "method": "Saliency", + }, + }, + }, + {"min": -1.0, "max": 1.0}, + ), + ( + lazy_fixture("load_1d_3ch_conv_model"), + lazy_fixture("almost_uniform_1d_no_abatch"), + { + "init": { + "num_classes": 10, + "normalise": True, + "disable_warnings": True, + "display_progressbar": 
True,
+                },
+                "call": {
+                    "softmax": True,
+                    "explain_func": explain,
+                    "explain_func_kwargs": {
+                        "method": "Saliency",
+                    },
+                },
+            },
+            {"min": -1.0, "max": 1.0},
+        ),
+        (
+            lazy_fixture("load_mnist_model"),
+            lazy_fixture("load_mnist_images"),
+            {
+                "init": {
+                    "num_classes": 10,
+                    "normalise": True,
+                    "disable_warnings": True,
+                    "display_progressbar": True,
+                },
+                "call": {
+                    "softmax": True,
+                    "explain_func": explain,
+                    "explain_func_kwargs": {
+                        "method": "Saliency",
+                    },
+                },
+            },
+            {"min": -1.0, "max": 1.0},
+        ),
+        (
+            lazy_fixture("titanic_model_torch"),
+            lazy_fixture("titanic_dataset"),
+            {
+                "init": {
+                    "num_classes": 2,
+                    "normalise": True,
+                    "abs": True,
+                    "disable_warnings": True,
+                },
+                "call": {
+                    "softmax": True,
+                    "explain_func": explain,
+                    "explain_func_kwargs": {
+                        "method": "IntegratedGradients",
+                        "reduce_axes": (),
+                    },
+                },
+            },
+            {"min": -1.0, "max": 1.1},
+        ),
+        (
+            lazy_fixture("titanic_model_tf"),
+            lazy_fixture("titanic_dataset"),
+            {
+                "init": {
+                    "num_classes": 2,
+                    "normalise": True,
+                    "abs": True,
+                    "disable_warnings": True,
+                    "similarity_func": correlation_pearson,
+                },
+                "call": {"softmax": True, "explain_func": explain_func_stub},
+            },
+            {"min": -1.0, "max": 1.1},
+        ),
+    ],
+)
+def test_random_logit(
+    model: ModelInterface,
+    data: np.ndarray,
+    params: dict,
+    expected: Union[float, dict, bool],
+):
+    x_batch, y_batch = (
+        data["x_batch"],
+        data["y_batch"],
+    )
+
+    init_params = params.get("init", {})
+    call_params = params.get("call", {})
+
+    if params.get("a_batch_generate", True):
+        explain = call_params["explain_func"]
+        explain_func_kwargs = call_params.get("explain_func_kwargs", {})
+        a_batch = explain(
+            model=model,
+            inputs=x_batch,
+            targets=y_batch,
+            **explain_func_kwargs,
+        )
+    elif "a_batch" in data:
+        a_batch = data["a_batch"]
+    else:
+        a_batch = None
+
+    if "exception" in expected:
+        with pytest.raises(expected["exception"]):
+            scores_layers = RandomLogit(**init_params)(
+                model=model,
+                x_batch=x_batch,
+                y_batch=y_batch,
+                a_batch=a_batch,
+                **call_params,
+            )
+        return
+
+    scores = RandomLogit(**init_params)(
+        model=model,
+        x_batch=x_batch,
+        y_batch=y_batch,
+        a_batch=a_batch,
+        **call_params,
+    )
+    assert all(
+        expected["min"] <= s <= expected["max"] for s in scores
+    ), f"Test failed with scores {scores}."
diff --git a/tox.ini b/tox.ini
index 745a8d6f9..0dac0bb3d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -41,8 +41,8 @@ description = Check the code style
 deps =
     flake8
 commands =
-    python3 -m flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+    python3 -m flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=venv
-    python3 -m flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    python3 -m flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude=venv [testenv:type] description = Run type checking From ac1cc261a9d960714a26e825b1da6eb3cc49c5de Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Thu, 23 Nov 2023 12:57:11 +0100 Subject: [PATCH 06/11] fixed issues with explanation method --- quantus/functions/explanation_func.py | 274 ++++++++++---------- quantus/metrics/randomisation/__init__.py | 2 +- tests/metrics/test_randomisation_metrics.py | 12 +- 3 files changed, 143 insertions(+), 145 deletions(-) diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py index 74eb04819..961bbe906 100644 --- a/quantus/functions/explanation_func.py +++ b/quantus/functions/explanation_func.py @@ -275,160 +275,150 @@ def generate_tf_explanation( explanation: np.ndarray = np.zeros_like(inputs) - try: - if method in constants.DEPRECATED_XAI_METHODS_TF: - warnings.warn( - f"Explanation method string {method} is deprecated. Use " - f"{constants.DEPRECATED_XAI_METHODS_TF[method]} instead.\n", - category=UserWarning, - ) - method = constants.DEPRECATED_XAI_METHODS_TF[method] - - if method == "VanillaGradients": - explainer = tf_explain.core.vanilla_gradients.VanillaGradients() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, - ) - / 255 - ) - - elif method == "IntegratedGradients": - n_steps = kwargs.get("n_steps", 10) - explainer = tf_explain.core.integrated_gradients.IntegratedGradients() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, n_steps=n_steps, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, - ) - / 255 + if method in constants.DEPRECATED_XAI_METHODS_TF: + warnings.warn( + f"Explanation method string {method} is deprecated. 
Use " + f"{constants.DEPRECATED_XAI_METHODS_TF[method]} instead.\n", + category=UserWarning, + ) + method = constants.DEPRECATED_XAI_METHODS_TF[method] + + if method == "VanillaGradients": + explainer = tf_explain.core.vanilla_gradients.VanillaGradients() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), model, y, **xai_lib_kwargs + ), + inputs, + targets, + ) + ), + dtype=float, ) + / 255 + ) - elif method == "GradientsInput": - explainer = tf_explain.core.gradients_inputs.GradientsInputs() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, - ) - / 255 + elif method == "IntegratedGradients": + n_steps = kwargs.get("n_steps", 10) + explainer = tf_explain.core.integrated_gradients.IntegratedGradients() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), model, y, n_steps=n_steps, **xai_lib_kwargs + ), + inputs, + targets, + ) + ), + dtype=float, ) + / 255 + ) - elif method == "OcclusionSensitivity": - patch_size = kwargs.get("window", (1, *([4] * (inputs.ndim - 2))))[-1] - reduce_axes = kwargs.get("reduce_axes", (-1,)) - keepdims = kwargs.get("keepdims", False) - keep_dim = False - explainer = tf_explain.core.occlusion_sensitivity.OcclusionSensitivity() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), - model, - y, - patch_size=patch_size, - **xai_lib_kwargs, - ), - inputs, - targets, - ) - ), - dtype=float, - ) - / 255 + elif method == "GradientsInput": + explainer = tf_explain.core.gradients_inputs.GradientsInputs() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), model, y, **xai_lib_kwargs + ), + inputs, + targets, + ) + ), + dtype=float, ) + / 255 + ) - elif method == "GradCAM": - reduce_axes = kwargs.get("reduce_axes", (-1,)) - keepdims = kwargs.get("keepdims", False) - keep_dim = False - if "gc_layer" in kwargs: - xai_lib_kwargs["layer_name"] = kwargs["gc_layer"] - - explainer = tf_explain.core.grad_cam.GradCAM() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), model, y, **xai_lib_kwargs - ), - inputs, - targets, - ) - ), - dtype=float, - ) - / 255 + elif method == "OcclusionSensitivity": + patch_size = kwargs.get("window", (1, *([4] * (inputs.ndim - 2))))[-1] + reduce_axes = kwargs.get("reduce_axes", (-1,)) + keepdims = kwargs.get("keepdims", False) + keep_dim = False + explainer = tf_explain.core.occlusion_sensitivity.OcclusionSensitivity() + explanation = ( + np.array( + list( + map( + lambda x, y: explainer.explain( + ([x], None), + model, + y, + patch_size=patch_size, + **xai_lib_kwargs, + ), + inputs, + targets, + ) + ), + dtype=float, ) + / 255 + ) - elif method == "SmoothGrad": - - num_samples = kwargs.get("num_samples", 5) - noise = kwargs.get("noise", 0.1) - explainer = tf_explain.core.smoothgrad.SmoothGrad() - explanation = ( - np.array( - list( - map( - lambda x, y: explainer.explain( - ([x], None), - model, - y, - num_samples=num_samples, - noise=noise, - **xai_lib_kwargs, - ), - inputs, - targets, - ) - ), - dtype=float, - ) - / 255 + elif method == "GradCAM": + reduce_axes = kwargs.get("reduce_axes", (-1,)) + keepdims = kwargs.get("keepdims", False) + keep_dim = False + if "gc_layer" in kwargs: + xai_lib_kwargs["layer_name"] = kwargs["gc_layer"] + + explainer = tf_explain.core.grad_cam.GradCAM() + explanation = ( + np.array( + list( + 
map(
+ lambda x, y: explainer.explain(
+ ([x], None), model, y, **xai_lib_kwargs
+ ),
+ inputs,
+ targets,
+ )
+ ),
+ dtype=float,
+ )
+ / 255
+ )
- else:
- raise KeyError(
- f"Specify a XAI method that already has been implemented {constants.AVAILABLE_XAI_METHODS_TF}."
+ elif method == "SmoothGrad":
+
+ num_samples = kwargs.get("num_samples", 5)
+ noise = kwargs.get("noise", 0.1)
+ explainer = tf_explain.core.smoothgrad.SmoothGrad()
+ explanation = (
+ np.array(
+ list(
+ map(
+ lambda x, y: explainer.explain(
+ ([x], None),
+ model,
+ y,
+ num_samples=num_samples,
+ noise=noise,
+ **xai_lib_kwargs,
+ ),
+ inputs,
+ targets,
+ )
+ ),
+ dtype=float,
+ )
+ / 255
+ )
- except ValueError as e:
- if "must be at least three-dimensional" in str(e):
- # Handle the specific error here
- warnings.warn(
- "Input data must be at least three-dimensional for tf-explain methods. "
- "Returning explanations with random uniform values of the same shape as inputs.\n",
- UserWarning,
- )
- raise ValueError
+ else:
+ raise KeyError(
+ f"To use the 'quantus.explain' method with tf-explain as a supporting library, "
+ f"specify an XAI method that is supported {constants.AVAILABLE_XAI_METHODS_TF}."
+ )
 assert 0 not in reduce_axes, (
 "Reduction over batch_axis is not available, please do not "
diff --git a/quantus/metrics/randomisation/__init__.py b/quantus/metrics/randomisation/__init__.py
index b3ffa7d35..f20808276 100644
--- a/quantus/metrics/randomisation/__init__.py
+++ b/quantus/metrics/randomisation/__init__.py
@@ -4,7 +4,7 @@
 # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see .
 # Quantus project URL: .
-from quantus.metrics.randomisation.mprt import MPRT
+from quantus.metrics.randomisation.mprt import MPRT, ModelParameterRandomisation
 from quantus.metrics.randomisation.efficient_mprt import EfficientMPRT
 from quantus.metrics.randomisation.smooth_mprt import SmoothMPRT
 from quantus.metrics.randomisation.random_logit import RandomLogit
diff --git a/tests/metrics/test_randomisation_metrics.py b/tests/metrics/test_randomisation_metrics.py
index 524c6abfe..95dd75e67 100644
--- a/tests/metrics/test_randomisation_metrics.py
+++ b/tests/metrics/test_randomisation_metrics.py
@@ -118,6 +118,7 @@ def explain_func_stub(*args, **kwargs):
 "display_progressbar": False,
 },
 "call": {
+ "batch_size": 2,
 "explain_func": explain,
 "explain_func_kwargs": {
 "method": "VanillaGradients",
@@ -1057,7 +1058,10 @@ def test_smooth_model_parameter_randomisation(
 "compute_extra_scores": False,
 "skip_layers": False,
 },
- "call": {"explain_func": explain_func_stub},
+ "call": {
+ "explain_func": explain_func_stub,
+ "batch_size": 2,
+ },
 },
 {"min": -1000000000, "max": 1000000000},
 ),
@@ -1190,6 +1194,7 @@ def test_efficient_model_parameter_randomisation(
 },
 "call": {
 "softmax": True,
+ "batch_size": 2,
 "explain_func": explain,
 "explain_func_kwargs": {
 "method": "Saliency",
@@ -1312,7 +1317,10 @@ def test_efficient_model_parameter_randomisation(
 "disable_warnings": True,
 "similarity_func": correlation_pearson,
 },
- "call": {"softmax": True, "explain_func": explain_func_stub},
+ "call": {
+ "softmax": True,
+ "explain_func": explain_func_stub,
+ },
 },
 {"min": -1.0, "max": 1.1},
 ),

From a95a614b4993bf32a57bf56aafbb1ba6b599980f Mon Sep 17 00:00:00 2001
From: annahedstroem
Date: Thu, 23 Nov 2023 13:19:21 +0100
Subject: [PATCH 07/11] update tox to exclude other packages

---
 tox.ini | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tox.ini b/tox.ini index
0dac0bb3d..b1d3c9b01 100644 --- a/tox.ini +++ b/tox.ini @@ -41,8 +41,8 @@ description = Check the code style deps = flake8 commands = - python3 -m flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=venv - python3 -m flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude=venv + python3 -m flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=venv,.tox + python3 -m flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude=venv,.tox [testenv:type] description = Run type checking From 837824eb6d4d169ff6a247b6a0a1ac7e057244cb Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Thu, 23 Nov 2023 13:50:31 +0100 Subject: [PATCH 08/11] typing fixes --- README.md | 5 +---- quantus/helpers/model/model_interface.py | 1 + .../metrics/randomisation/efficient_mprt.py | 14 ++++++------ quantus/metrics/randomisation/mprt.py | 2 +- quantus/metrics/randomisation/smooth_mprt.py | 22 +------------------ 5 files changed, 11 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index c60ebe61e..3d86969fb 100644 --- a/README.md +++ b/README.md @@ -23,14 +23,11 @@ _Quantus is currently under active development so carefully note the Quantus rel ## News and Highlights! :rocket: -- Released a new version [v0.4.3](https://github.com/understandable-machine-intelligence-lab/Quantus/releases) +- Released a new version [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases) - Accepted to Journal of Machine Learning Research (MLOSS), read the [paper](https://jmlr.org/papers/v24/22-0142.html) - Offers more than **30+ metrics in 6 categories** for XAI evaluation - Supports different data types (image, time-series, tabular, NLP next up!) and models (PyTorch, TensorFlow) - Extended built-in support for explanation methods ([captum](https://captum.ai/), [tf-explain](https://tf-explain.readthedocs.io/en/latest/) and [zennit](https://github.com/chr5tphr/zennit)) -- New optimisations to help speed up computation, see API reference [here](https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.base_batched.html) - -See [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases) for the latest release(s). ## Citation diff --git a/quantus/helpers/model/model_interface.py b/quantus/helpers/model/model_interface.py index 9d7e924e9..ebd040df6 100644 --- a/quantus/helpers/model/model_interface.py +++ b/quantus/helpers/model/model_interface.py @@ -217,3 +217,4 @@ def get_ml_framework_name(self) -> str: else: warnings.warn("Cannot identify ML framework of the given model.") return "unknown" + return "" diff --git a/quantus/metrics/randomisation/efficient_mprt.py b/quantus/metrics/randomisation/efficient_mprt.py index 9b6cc99ae..1697a328f 100644 --- a/quantus/metrics/randomisation/efficient_mprt.py +++ b/quantus/metrics/randomisation/efficient_mprt.py @@ -216,7 +216,7 @@ def __call__( device: Optional[str] = None, batch_size: int = 64, **kwargs, - ) -> Union[List[float], float, Dict[str, List[float]], Collection[Any]]: + ) -> List[Any]: """ This implementation represents the main logic of the metric and makes the class object callable. 
It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), @@ -348,8 +348,8 @@ def __call__( debug=self.complexity_func_kwargs.get("debug", False), ) - self.explanation_scores_by_layer = {} - self.model_scores_by_layer = {} + self.explanation_scores_by_layer: Dict[str, List[float]] = {} + self.model_scores_by_layer: Dict[str, List[float]] = {} with pbar as pbar: for l_ix, (layer_name, random_layer_model) in enumerate( @@ -611,7 +611,7 @@ def evaluate_batch(self, *args, **kwargs): def recompute_model_explanation_correlation_per_sample( self, - ) -> Union[List[List[Any]], Dict[int, List[Any]]]: + ) -> List[Union[float, Any]]: assert isinstance(self.explanation_scores_by_layer, dict), ( "To compute the correlation between model and explanation per sample for " @@ -647,7 +647,7 @@ def recompute_model_explanation_correlation_per_sample( def recompute_average_complexity_per_sample( self, - ) -> Union[List[List[Any]], Dict[int, List[Any]]]: + ) -> List[float]: assert isinstance(self.explanation_scores_by_layer, dict), ( "To compute the average correlation coefficient per sample for " @@ -670,13 +670,13 @@ def recompute_average_complexity_per_sample( ) results[sample] = np.mean(results[sample]) - corr_coeffs = list(results.values()) + corr_coeffs = np.array(list(results.values())).flatten().tolist() return corr_coeffs def recompute_last_correlation_per_sample( self, - ) -> Union[List[List[Any]], Dict[int, List[Any]]]: + ) -> List[float]: assert isinstance(self.explanation_scores_by_layer, dict), ( "To compute the last correlation coefficient per sample for " diff --git a/quantus/metrics/randomisation/mprt.py b/quantus/metrics/randomisation/mprt.py index a11468765..edf418cfe 100644 --- a/quantus/metrics/randomisation/mprt.py +++ b/quantus/metrics/randomisation/mprt.py @@ -416,7 +416,7 @@ def recompute_average_correlation_per_sample( results[sample].append(float(self.evaluation_scores[layer][sample])) results[sample] = np.mean(results[sample]) - corr_coeffs = list(results.values()) + corr_coeffs = np.array(list(results.values())).flatten().tolist() return corr_coeffs diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index b4342c705..57e548aed 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -442,7 +442,7 @@ def recompute_average_correlation_per_sample( results[sample].append(float(self.evaluation_scores[layer][sample])) results[sample] = np.mean(results[sample]) - corr_coeffs = list(results.values()) + corr_coeffs = np.array(list(results.values())).flatten().tolist() return corr_coeffs @@ -499,26 +499,6 @@ def evaluate_instance( a_flat = a.flatten() a_perturbed_flat = a_perturbed.flatten() - if np.array_equal(a_flat, a_perturbed_flat): - warnings.warn( - "The arrays 'a_perturbed' and 'a' are identical. 
" - "Returning a similarity measure of 1.", - UserWarning, - ) - return 1.0 - - # Check if either array is constant - if np.all(a_flat == a_flat[0]) or np.all( - a_perturbed_flat == a_perturbed_flat[0] - ): - warnings.warn( - "One of the input arrays is constant; " - "the correlation coefficient is not defined.", - UserWarning, - ) - return 1.0 # or some other default value - - # Compute similarity measure try: return self.similarity_func(a_perturbed_flat, a_flat) except stats._warnings_errors.ConstantInputWarning: From 5455df2d222f2ce19d7433fc5f444387e9eadbfa Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Thu, 23 Nov 2023 14:02:32 +0100 Subject: [PATCH 09/11] added test cases for new functions --- quantus/functions/complexity_func.py | 6 ++-- tests/functions/__init__.py | 0 tests/functions/test_complexity_func.py | 31 +++++++++++++++++++ .../test_explanation_func.py | 0 .../{helpers => functions}/test_loss_func.py | 0 .../test_mosaic_func.py | 0 tests/functions/test_n_bins_func.py | 28 +++++++++++++++++ .../{helpers => functions}/test_norm_func.py | 0 .../test_normalise_func.py | 0 .../test_perturb_func.py | 0 .../test_pytorch_model.py | 0 .../test_similarity_func.py | 0 12 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 tests/functions/__init__.py create mode 100644 tests/functions/test_complexity_func.py rename tests/{helpers => functions}/test_explanation_func.py (100%) rename tests/{helpers => functions}/test_loss_func.py (100%) rename tests/{helpers => functions}/test_mosaic_func.py (100%) create mode 100644 tests/functions/test_n_bins_func.py rename tests/{helpers => functions}/test_norm_func.py (100%) rename tests/{helpers => functions}/test_normalise_func.py (100%) rename tests/{helpers => functions}/test_perturb_func.py (100%) rename tests/{helpers => functions}/test_pytorch_model.py (100%) rename tests/{helpers => functions}/test_similarity_func.py (100%) diff --git a/quantus/functions/complexity_func.py b/quantus/functions/complexity_func.py index 94c905ed4..bfc8be28a 100644 --- a/quantus/functions/complexity_func.py +++ b/quantus/functions/complexity_func.py @@ -17,7 +17,7 @@ def entropy(a: np.array, x: np.array, **kwargs) -> float: Parameters ---------- a: np.ndarray - Array to calculate entropy on. + Array to calculate entropy on. One sample at a time. x: np.ndarray Array to compute shape. kwargs: optional @@ -47,7 +47,7 @@ def gini_coeffiient(a: np.array, x: np.array, **kwargs) -> float: Parameters ---------- a: np.ndarray - Array to calculate gini_coeffiient on. + Array to calculate gini_coeffiient on. One sample at a time. x: np.ndarray Array to compute shape. kwargs: optional @@ -80,7 +80,7 @@ def discrete_entropy(a: np.array, x: np.array, **kwargs) -> float: Parameters ---------- a: np.ndarray - Array to calculate entropy on. + Array to calculate entropy on. One sample at a time. x: np.ndarray Array to compute shape. 
kwargs: optional diff --git a/tests/functions/__init__.py b/tests/functions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/functions/test_complexity_func.py b/tests/functions/test_complexity_func.py new file mode 100644 index 000000000..67f06d60b --- /dev/null +++ b/tests/functions/test_complexity_func.py @@ -0,0 +1,31 @@ +import pytest +import numpy as np +from quantus.functions.complexity_func import entropy, gini_coeffiient, discrete_entropy + + +@pytest.fixture +def array_data(): + return np.random.rand(1, 32, 32), np.random.rand(1, 32, 32) + + +@pytest.mark.complexity_func +def test_entropy(array_data): + a, x = array_data + result = entropy(a, x) + assert isinstance(result, float), "Output should be a float." + + +@pytest.mark.complexity_func +def test_gini_coefficient(array_data): + a, x = array_data + result = gini_coeffiient(a, x) + assert isinstance(result, float), "Output should be a float." + assert 0 <= result <= 1, "Gini coefficient should be in the range [0, 1]." + + +@pytest.mark.complexity_func +@pytest.mark.parametrize("n_bins", [10, 50, 100]) +def test_discrete_entropy(array_data, n_bins): + a, x = array_data + result = discrete_entropy(a, x, n_bins=n_bins) + assert isinstance(result, float), "Output should be a float." diff --git a/tests/helpers/test_explanation_func.py b/tests/functions/test_explanation_func.py similarity index 100% rename from tests/helpers/test_explanation_func.py rename to tests/functions/test_explanation_func.py diff --git a/tests/helpers/test_loss_func.py b/tests/functions/test_loss_func.py similarity index 100% rename from tests/helpers/test_loss_func.py rename to tests/functions/test_loss_func.py diff --git a/tests/helpers/test_mosaic_func.py b/tests/functions/test_mosaic_func.py similarity index 100% rename from tests/helpers/test_mosaic_func.py rename to tests/functions/test_mosaic_func.py diff --git a/tests/functions/test_n_bins_func.py b/tests/functions/test_n_bins_func.py new file mode 100644 index 000000000..0ea7e98fa --- /dev/null +++ b/tests/functions/test_n_bins_func.py @@ -0,0 +1,28 @@ +import pytest +from pytest_lazyfixture import lazy_fixture + +from quantus.functions.n_bins_func import * +import numpy as np + + +@pytest.fixture +def batch_data(): + return np.random.rand(10, 32, 32, 3) + + +@pytest.mark.n_bins_func +@pytest.mark.parametrize( + "func,data", + [ + (freedman_diaconis_rule, lazy_fixture("batch_data")), + (scotts_rule, lazy_fixture("batch_data")), + (square_root_choice, lazy_fixture("batch_data")), + (sturges_formula, lazy_fixture("batch_data")), + (rice_rule, lazy_fixture("batch_data")), + ], +) +def test_n_bins_func(func, data: np.ndarray): + n_bins = func(data) + print(n_bins) + assert isinstance(n_bins, int), "Output should be an integer." + assert n_bins > 0, "Number of bins should be positive." 
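As context for the tests above: a minimal sketch of how the new bin-count heuristics and discrete_entropy are meant to compose when scoring explanation complexity. It is illustrative only and not part of the patch; the random arrays are placeholders for real attributions and inputs, and it uses only the functions visible in this series.

import numpy as np

from quantus.functions.complexity_func import discrete_entropy
from quantus.functions.n_bins_func import freedman_diaconis_rule

# Placeholder batches: 10 single-channel "attribution" and "input" samples.
a_batch = np.random.rand(10, 1, 32, 32)
x_batch = np.random.rand(10, 1, 32, 32)

# Choose the number of bins once per attribution batch...
n_bins = freedman_diaconis_rule(a_batch)

# ...then score complexity one sample at a time, as the docstrings require.
scores = [discrete_entropy(a, x, n_bins=n_bins) for a, x in zip(a_batch, x_batch)]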
diff --git a/tests/helpers/test_norm_func.py b/tests/functions/test_norm_func.py
similarity index 100%
rename from tests/helpers/test_norm_func.py
rename to tests/functions/test_norm_func.py
diff --git a/tests/helpers/test_normalise_func.py b/tests/functions/test_normalise_func.py
similarity index 100%
rename from tests/helpers/test_normalise_func.py
rename to tests/functions/test_normalise_func.py
diff --git a/tests/helpers/test_perturb_func.py b/tests/functions/test_perturb_func.py
similarity index 100%
rename from tests/helpers/test_perturb_func.py
rename to tests/functions/test_perturb_func.py
diff --git a/tests/helpers/test_pytorch_model.py b/tests/functions/test_pytorch_model.py
similarity index 100%
rename from tests/helpers/test_pytorch_model.py
rename to tests/functions/test_pytorch_model.py
diff --git a/tests/helpers/test_similarity_func.py b/tests/functions/test_similarity_func.py
similarity index 100%
rename from tests/helpers/test_similarity_func.py
rename to tests/functions/test_similarity_func.py

From 49954e95af78d5b1ab5c808cb3199f9f919859a7 Mon Sep 17 00:00:00 2001
From: annahedstroem
Date: Fri, 24 Nov 2023 10:42:08 +0100
Subject: [PATCH 10/11] smaller fixes

---
 README.md | 1 +
 quantus/functions/complexity_func.py | 12 +++--
 quantus/functions/n_bins_func.py | 75 +++++++++++++++++++++++----
 quantus/helpers/utils.py | 20 +++----
 quantus/metrics/randomisation/mprt.py | 7 ++-
 5 files changed, 88 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index 3d86969fb..5d1f4f03e 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ _Quantus is currently under active development so carefully note the Quantus rel
 ## News and Highlights! :rocket:
+- New metrics added: [EfficientMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/efficient_mprt.py) and [SmoothMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/smooth_mprt.py) by [Hedström et al., (2023)](https://openreview.net/forum?id=vVpefYmnsG)
 - Released a new version [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases)
 - Accepted to Journal of Machine Learning Research (MLOSS), read the [paper](https://jmlr.org/papers/v24/22-0142.html)
 - Offers more than **30+ metrics in 6 categories** for XAI evaluation
 - Supports different data types (image, time-series, tabular, NLP next up!) and models (PyTorch, TensorFlow)
 - Extended built-in support for explanation methods ([captum](https://captum.ai/), [tf-explain](https://tf-explain.readthedocs.io/en/latest/) and [zennit](https://github.com/chr5tphr/zennit))
diff --git a/quantus/functions/complexity_func.py b/quantus/functions/complexity_func.py
index bfc8be28a..2c49727bc 100644
--- a/quantus/functions/complexity_func.py
+++ b/quantus/functions/complexity_func.py
@@ -12,7 +12,7 @@ def entropy(a: np.array, x: np.array, **kwargs) -> float:
 """
- Calculate entropy.
+ Calculate entropy of a single array.
 Parameters
 ----------
@@ -42,12 +42,12 @@ def entropy(a: np.array, x: np.array, **kwargs) -> float:
 def gini_coeffiient(a: np.array, x: np.array, **kwargs) -> float:
 """
- Calculate Gini coefficient.
+ Calculate Gini coefficient of a single array.
 Parameters
 ----------
 a: np.ndarray
- Array to calculate gini_coeffiient on. One sample at a time.
+ Array to calculate gini_coeffiient on.
 x: np.ndarray
 Array to compute shape.
 kwargs: optional
@@ -76,11 +76,13 @@ def gini_coeffiient(a: np.array, x: np.array, **kwargs) -> float:
 def discrete_entropy(a: np.array, x: np.array, **kwargs) -> float:
 """
- Calculate discrete entropy of explanations with n_bins equidistant spaced bins
+ Calculate discrete entropy of explanations with n_bins
+ equidistantly spaced bins of a single array.
+ Parameters ---------- a: np.ndarray - Array to calculate entropy on. One sample at a time. + Array to calculate entropy on. x: np.ndarray Array to compute shape. kwargs: optional diff --git a/quantus/functions/n_bins_func.py b/quantus/functions/n_bins_func.py index 6679d6790..42319b892 100644 --- a/quantus/functions/n_bins_func.py +++ b/quantus/functions/n_bins_func.py @@ -10,8 +10,19 @@ import numpy as np -def freedman_diaconis_rule(a_batch: np.array) -> int: - """Freedman–Diaconis' rule.""" +def freedman_diaconis_rule(a_batch: np.ndarray) -> int: + """ + Freedman–Diaconis' rule to compute the number of bins. + + Parameters + ---------- + a_batch: np.ndarray + The batch of attributions to use in the calculation. + + Returns + ------- + integer + """ iqr = np.percentile(a_batch, 75) - np.percentile(a_batch, 25) n = a_batch[0].ndim @@ -27,8 +38,19 @@ def freedman_diaconis_rule(a_batch: np.array) -> int: return n_bins -def scotts_rule(a_batch: np.array) -> int: - """Scott's rule.""" +def scotts_rule(a_batch: np.ndarray) -> int: + """ + Scott's rule to compute the number of bins. + + Parameters + ---------- + a_batch: np.ndarray + The batch of attributions to use in the calculation. + + Returns + ------- + integer + """ std = np.std(a_batch) n = a_batch[0].ndim @@ -42,8 +64,19 @@ def scotts_rule(a_batch: np.array) -> int: return n_bins -def square_root_choice(a_batch: np.array) -> int: - """Square-root choice rule.""" +def square_root_choice(a_batch: np.ndarray) -> int: + """ + Square-root choice rule to compute the number of bins. + + Parameters + ---------- + a_batch: np.ndarray + The batch of attributions to use in the calculation. + + Returns + ------- + integer + """ n = a_batch[0].ndim n_bins = int(np.sqrt(n)) @@ -51,8 +84,19 @@ def square_root_choice(a_batch: np.array) -> int: return n_bins -def sturges_formula(a_batch: np.array) -> int: - """Sturges' formula.""" +def sturges_formula(a_batch: np.ndarray) -> int: + """ + Sturges' rule to compute the number of bins. + + Parameters + ---------- + a_batch: np.ndarray + The batch of attributions to use in the calculation. + + Returns + ------- + integer + """ n = a_batch[0].ndim n_bins = int(np.log2(n) + 1) @@ -60,8 +104,19 @@ def sturges_formula(a_batch: np.array) -> int: return n_bins -def rice_rule(a_batch: np.array) -> int: - """Rice Rule.""" +def rice_rule(a_batch: np.ndarray) -> int: + """ + Rice rule to compute the number of bins. + + Parameters + ---------- + a_batch: np.ndarray + The batch of attributions to use in the calculation. + + Returns + ------- + integer + """ n = a_batch[0].ndim n_bins = int(2 * np.power(n, 1 / 3)) diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 784a9aa8d..b3ffd577d 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -227,14 +227,14 @@ def infer_channel_first(x: np.array) -> bool: Infer if the channels are first. Assumes: - For 1d input: + For 1D input: nr_channels < sequence_length - For 2d input: + For 2D input: nr_channels < img_width and nr_channels < img_height - For higher dimensional input an error is raised. For input in n_features x n_batches format True is returned (no - channel). + For higher dimensional input an error is raised. + For input in n_features x n_batches format True is returned (no channel). Parameters ---------- @@ -281,7 +281,7 @@ def infer_channel_first(x: np.array) -> bool: else: raise ValueError( - "Only batched 2D and 3D multi-channel input dimensions supported." 
+ "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." ) @@ -310,7 +310,7 @@ def make_channel_first(x: np.array, channel_first: bool = False): else: raise ValueError( - "Only batched 2D and 3D multi-channel input dimensions supported." + "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." ) @@ -338,7 +338,7 @@ def make_channel_last(x: np.array, channel_first: bool = True): return np.moveaxis(x, -3, -1) else: raise ValueError( - "Only batched 2D and 3D multi-channel input dimensions supported." + "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." ) @@ -512,7 +512,7 @@ def create_patch_slice( if len(patch_size) == 1 and len(coords) != 1: patch_size = tuple(patch_size for _ in coords) elif patch_size.ndim != 1: - raise ValueError("patch_size has to be either a scalar or a 1d-sequence") + raise ValueError("patch_size has to be either a scalar or a 1D-sequence") elif len(patch_size) != len(coords): raise ValueError( "patch_size sequence length does not match coords length" @@ -557,7 +557,7 @@ def get_nr_patches( if len(patch_size) == 1 and len(shape) != 1: patch_size = tuple(patch_size for _ in shape) elif patch_size.ndim != 1: - raise ValueError("patch_size has to be either a scalar or a 1d-sequence") + raise ValueError("patch_size has to be either a scalar or a 1D-sequence") elif len(patch_size) != len(shape): raise ValueError( "patch_size sequence length does not match shape length" @@ -839,7 +839,7 @@ def expand_indices( Expands indices to fit array shape. Returns expanded indices. --> if indices are a sequence of ints, they are interpreted as indices to the flattened arr, and subsequently expanded - --> if indices contains only slices and 1d sequences for arr, everything is interpreted as slices + --> if indices contains only slices and 1D sequences for arr, everything is interpreted as slices --> if indices contains already expanded indices, they are returned as is Parameters diff --git a/quantus/metrics/randomisation/mprt.py b/quantus/metrics/randomisation/mprt.py index edf418cfe..078a1504c 100644 --- a/quantus/metrics/randomisation/mprt.py +++ b/quantus/metrics/randomisation/mprt.py @@ -554,9 +554,12 @@ def evaluate_batch(self, *args, **kwargs): @final class ModelParameterRandomisation(MPRT): def __init__(self, *args, **kwargs): + warnings.simplefilter("always", DeprecationWarning) warnings.warn( - "ModelParameterRandomisation has been renamed to MPRT and will be removed in future releases. " - "Please use MPRT instead. This change is effective from Quantus version 0.5.0.", + "ModelParameterRandomisation metric has been renamed to MPRT and will " + "be removed in future releases. Please call quantus.MPRT() instead. " + "This change is effective from Quantus version 0.5.0. 
Note: MPRT is " + "functionally identical to ModelParameterRandomisation and can be used in the same way.", DeprecationWarning, ) super().__init__(*args, **kwargs) From d21dec0791e5172eb9365771bf92ec7c10aa8067 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Fri, 24 Nov 2023 11:40:02 +0100 Subject: [PATCH 11/11] small docstring fixes --- quantus/metrics/base.py | 4 +- .../metrics/randomisation/efficient_mprt.py | 3 +- quantus/metrics/randomisation/mprt.py | 1 + quantus/metrics/randomisation/smooth_mprt.py | 70 ++----------------- 4 files changed, 8 insertions(+), 70 deletions(-) diff --git a/quantus/metrics/base.py b/quantus/metrics/base.py index 999520f1e..fef819788 100644 --- a/quantus/metrics/base.py +++ b/quantus/metrics/base.py @@ -885,7 +885,7 @@ def explain_batch( y_batch: np.ndarray, ) -> np.ndarray: """ - Compute explanations, normalize and take absolute (if was configured so during metric initialization.) + Compute explanations, normalise and take absolute (if was configured so during metric initialization.) This method should primarily be used if you need to generate additional explanation in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. It will do few things: @@ -893,7 +893,7 @@ def explain_batch( - unwrap model (if ModelInterface instance was provided) - call explain_func - expand attribution channel - - (optionally) normalize a_batch + - (optionally) normalise a_batch - (optionally) take np.abs of a_batch diff --git a/quantus/metrics/randomisation/efficient_mprt.py b/quantus/metrics/randomisation/efficient_mprt.py index 1697a328f..737d04129 100644 --- a/quantus/metrics/randomisation/efficient_mprt.py +++ b/quantus/metrics/randomisation/efficient_mprt.py @@ -508,8 +508,6 @@ def evaluate_instance( Parameters ---------- - i: integer - The evaluation instance. model: ModelInterface A ModelInteface that is subject to explanation. x: np.ndarray @@ -578,6 +576,7 @@ def generate_explanations( ) -> Generator[np.ndarray, None, None]: """ Iterate over dataset in batches and generate explanations for complete dataset. + Parameters ---------- model: ModelInterface diff --git a/quantus/metrics/randomisation/mprt.py b/quantus/metrics/randomisation/mprt.py index 078a1504c..17efd64e1 100644 --- a/quantus/metrics/randomisation/mprt.py +++ b/quantus/metrics/randomisation/mprt.py @@ -519,6 +519,7 @@ def generate_explanations( ) -> Generator[np.ndarray, None, None]: """ Iterate over dataset in batches and generate explanations for complete dataset. + Parameters ---------- model: ModelInterface diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index 57e548aed..81b2efd4e 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -475,8 +475,6 @@ def evaluate_instance( Parameters ---------- - i: integer - The evaluation instance. model: ModelInterface A ModelInteface that is subject to explanation. x: np.ndarray @@ -602,15 +600,15 @@ def explain_smooth_batch( **kwargs, ) -> np.ndarray: """ - Compute explanations, normalize and take absolute (if was configured so during metric initialization.) + Compute explanations, normalise and take absolute (if was configured so during metric initialization.) This method should primarily be used if you need to generate additional explanation in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. 
It will do few things: - call model.shape_input (if ModelInterface instance was provided) - unwrap model (if ModelInterface instance was provided) - - call explain_func + - add noise to input and call explain_func via explain_smooth_batch_numpy - expand attribution channel - - (optionally) normalize a_batch + - (optionally) normalise a_batch - (optionally) take np.abs of a_batch Parameters @@ -670,7 +668,7 @@ def explain_smooth_batch_numpy( **kwargs, ) -> np.ndarray: """ - Compute explanations, normalize and take absolute (if was configured so during metric initialization.) + Compute explanations, normalise and take absolute (if was configured so during metric initialization.) This method should primarily be used if you need to generate additional explanation in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. It will do few things: @@ -712,63 +710,3 @@ def explain_smooth_batch_numpy( a_batch_smooth += a_batch / self.nr_samples return a_batch_smooth - - def explain_smooth_batch_torch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - std: float, - **kwargs, - ) -> np.ndarray: - """ - Compute explanations, normalize and take absolute (if was configured so during metric initialization.) - This method should primarily be used if you need to generate additional explanation - in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach. - It will do few things: - - call model.shape_input (if ModelInterface instance was provided) - - unwrap model (if ModelInterface instance was provided) - - call explain_func - - expand attribution channel - - Parameters - ------- - model: - A model that is subject to explanation. - x_batch: - A np.ndarray which contains the input data that are explained. - y_batch: - A np.ndarray which contains the output labels that are explained. - std : float - Standard deviation of the Gaussian noise. - kwargs: optional, dict - List of hyperparameters. - - Returns - ------- - a_batch: - Batch of explanations ready to be evaluated. - """ - if not isinstance(x_batch, torch.Tensor): - x_batch = torch.Tensor(x_batch).to(self.device) - - if not isinstance(y_batch, torch.Tensor): - y_batch = torch.as_tensor(y_batch).to(self.device) - - a_batch_smooth = torch.zeros_like(x_batch) - for n in range(self.nr_samples): - # the last epsilon is defined as zero to compute the true output, - # and have SmoothGrad w/ n_iter = 1 === gradient - if n == self.nr_samples - 1: - epsilon = torch.zeros_like(x_batch) - else: - epsilon = torch.randn_like(x_batch) * std - - a_batch = quantus.explain(model, x_batch + epsilon, y_batch, **kwargs) - - if a_batch_smooth is None: - a_batch_smooth = a_batch / self.nr_samples - else: - a_batch_smooth += a_batch / self.nr_samples - - return a_batch_smooth
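To round off the series: a minimal usage sketch of the renamed metric. This is illustrative only and not part of the patch; it assumes quantus re-exports MPRT, ModelParameterRandomisation and explain at package level (as the deprecation message above indicates), and the tiny torch model and random data are placeholders.

import numpy as np
import torch

import quantus

# Placeholder classifier and data, just to make the sketch self-contained.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
x_batch = np.random.rand(8, 1, 28, 28).astype(np.float32)
y_batch = np.random.randint(0, 10, size=8)

# New entry point; functionally identical to the old ModelParameterRandomisation.
scores = quantus.MPRT(disable_warnings=True)(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=None,
    explain_func=quantus.explain,
    explain_func_kwargs={"method": "Saliency"},
    softmax=True,
    batch_size=2,
)

# The old name still resolves to the same metric, but constructing it
# now emits a DeprecationWarning pointing at quantus.MPRT.
legacy_metric = quantus.ModelParameterRandomisation(disable_warnings=True)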