From 4b44e2b86af1dbf75a91a2de2af86550dae7127e Mon Sep 17 00:00:00 2001
From: annahedstroem
Date: Fri, 24 Nov 2023 18:30:56 +0100
Subject: [PATCH] inverse

---
 quantus/evaluation.py                 |  90 +++++--
 quantus/metrics/inverse_estimation.py | 339 ++++++++++++++++++++++++++
 2 files changed, 408 insertions(+), 21 deletions(-)
 create mode 100644 quantus/metrics/inverse_estimation.py

diff --git a/quantus/evaluation.py b/quantus/evaluation.py
index 102c0980..2ced0230 100644
--- a/quantus/evaluation.py
+++ b/quantus/evaluation.py
@@ -31,41 +31,89 @@ def evaluate(
     **kwargs,
 ) -> Optional[dict]:
     """
-    A method to evaluate some explanation methods given some metrics.
+    Evaluate different explanation methods using specified metrics.
 
     Parameters
     ----------
-    metrics: dict
-        A dictionary with intialised metrics.
-    xai_methods: dict, list
-        Pass the different explanation methods as:
-        1) Dict[str, np.ndarray] where values are pre-calculcated attributions, or
-        2) Dict[str, Dict] where the keys are the name of the Quantus build-in explanation methods,
-        and the values are the explain function keyword arguments as a dictionary, or
-        3) Dict[str, Callable] where the keys are the name of explanation methods,
-        and the values a callable explanation function.
-    model: torch.nn.Module, tf.keras.Model
-        A torch or tensorflow model e.g., torchvision.models that is subject to explanation.
+    metrics : dict
+        A dictionary of initialized evaluation metrics. See quantus.AVAILABLE_METRICS.
+        Example: {'Robustness': quantus.MaxSensitivity(), 'Faithfulness': quantus.PixelFlipping()}
+
+    xai_methods : dict
+        A dictionary specifying the explanation methods to evaluate, which can be structured in three ways:
+
+        1) Dict[str, Dict] for built-in Quantus methods:
+
+            Example:
+            xai_methods = {
+                'IntegratedGradients': {
+                    'n_steps': 10,
+                    'xai_lib': 'captum'
+                },
+                'Saliency': {
+                    'xai_lib': 'captum'
+                }
+            }
+
+            - See quantus.AVAILABLE_XAI_METHODS_CAPTUM for supported captum methods.
+            - See quantus.AVAILABLE_XAI_METHODS_TF for supported tensorflow methods.
+            - See https://github.com/chr5tphr/zennit for supported zennit methods.
+            - Read more about the explanation function arguments here:
+
+        2) Dict[str, Callable] for custom methods:
+
+            Example:
+            xai_methods = {
+                'custom_own_xai_method': custom_explain_function
+            }
+
+            - Here, you can provide your own callable that mirrors the inputs and outputs of the quantus.explain() method.
+
+        3) Dict[str, np.ndarray] for pre-calculated attributions:
+
+            Example:
+            xai_methods = {
+                'LIME': precomputed_numpy_lime_attributions,
+                'GradientShap': precomputed_numpy_shap_attributions
+            }
+
+            - Note that some metrics, e.g., quantus.MaxSensitivity() within the robustness category,
+              require the explanation function to be passed (as this is used in the evaluation logic).
+
+        It is also possible to pass a combination of the above.
+
+    model: Union[torch.nn.Module, tf.keras.Model]
+        A torch or tensorflow model that is subject to explanation.
+
     x_batch: np.ndarray
-        A np.ndarray which contains the input data that are explained.
+        A np.ndarray containing the input data to be explained.
+
     y_batch: np.ndarray
-        A np.ndarray which contains the output labels that are explained.
+        A np.ndarray containing the output labels corresponding to x_batch.
+
     s_batch: np.ndarray, optional
-        A np.ndarray which contains segmentation masks that matches the input.
-    agg_func: callable
-        Indicates how to aggregates scores e.g., pass np.mean.
-    progress: boolean
-        Indicates if progress should be printed to std, or not.
+        A np.ndarray containing segmentation masks that match the input.
+
+    agg_func: Callable
+        Indicates how to aggregate scores, e.g., pass np.mean.
+
+    progress: bool
+        Indicates if progress should be printed to standard output.
+
     explain_func_kwargs: dict, optional
         Keyword arguments to be passed to explain_func on call. Pass None if using Dict[str, Dict] type for xai_methods.
+
     call_kwargs: Dict[str, Dict]
-        Keyword arguments for the call of the metrics, keys are names for arg set and values are argument dictionaries.
+        Keyword arguments for the call of the metrics. Keys are names for argument sets, and values are argument dictionaries.
+
     kwargs: optional
         Deprecated keyword arguments for the call of the metrics.
+
     Returns
     -------
     results: dict
-        A dictionary with the results.
+        A dictionary with the evaluation results.
     """
 
     warn.check_kwargs(kwargs)
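To make the three `xai_methods` structures documented above concrete, here is a minimal usage sketch of a combined `quantus.evaluate` call. The model, the data batch, `custom_explain_function` and the pre-computed attribution arrays are illustrative placeholders, not part of this patch:

```python
import numpy as np
import quantus

# Initialised metrics to evaluate against (see quantus.AVAILABLE_METRICS).
metrics = {
    "Robustness": quantus.MaxSensitivity(nr_samples=10),
    "Faithfulness": quantus.PixelFlipping(),
}

# All three documented forms combined: built-in method kwargs (Dict[str, Dict]),
# a custom callable (Dict[str, Callable]), and pre-computed attributions
# (Dict[str, np.ndarray]). The placeholder names are assumed to exist.
xai_methods = {
    "Saliency": {"xai_lib": "captum"},
    "custom_own_xai_method": custom_explain_function,
    "LIME": precomputed_numpy_lime_attributions,
}

results = quantus.evaluate(
    metrics=metrics,
    xai_methods=xai_methods,
    model=model,        # torch.nn.Module or tf.keras.Model.
    x_batch=x_batch,    # np.ndarray with the inputs to be explained.
    y_batch=y_batch,    # np.ndarray with the corresponding labels.
    agg_func=np.mean,   # How to aggregate the scores per metric.
    progress=True,
)
```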
diff --git a/quantus/metrics/inverse_estimation.py b/quantus/metrics/inverse_estimation.py
new file mode 100644
index 00000000..a4e393d0
--- /dev/null
+++ b/quantus/metrics/inverse_estimation.py
@@ -0,0 +1,339 @@
+"""This module contains the implementation of the Inverse Estimation metric."""
+
+# This file is part of Quantus.
+# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
+# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
+# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
+# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from quantus.helpers import asserts
+from quantus.helpers import plotting
+from quantus.helpers import utils
+from quantus.helpers import warn
+from quantus.helpers.model.model_interface import ModelInterface
+from quantus.functions.normalise_func import normalise_by_max
+from quantus.functions.perturb_func import baseline_replacement_by_indices
+from quantus.metrics.base_perturbed import Metric
+from quantus.helpers.enums import (
+    ModelType,
+    DataType,
+    ScoreDirection,
+    EvaluationCategory,
+)
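Since the class docstring below is still a stub ("The basic idea is to ..."), a toy sketch of the quantity that `__call__` computes may help reviewers: the wrapped metric is scored once on the given attributions and once on their negation, and the per-sample difference is returned. All numbers here are invented:

```python
import numpy as np

# Hypothetical per-sample scores of the wrapped metric, e.g., PixelFlipping,
# under the original attributions a and under the inverted attributions -a.
scores = np.array([0.9, 0.7, 0.8])          # metric(model, x, y, a)
scores_inverse = np.array([0.2, 0.6, 0.3])  # metric(model, x, y, -a)

# Inverse-estimation score: a faithful explanation should lose much of its
# score once its feature ranking is inverted, yielding a large difference.
inv_scores = scores - scores_inverse
print(inv_scores)  # [0.7 0.1 0.5]
```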
+ """ + + name = "Inverse-Estimation" + data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} + model_applicability = {ModelType.TORCH, ModelType.TF} + score_direction = ScoreDirection.LOWER + evaluation_category = EvaluationCategory.FAITHFULNESS + + def __init__( + self, + metric_init: Metric, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + perturb_func: Callable = None, + perturb_baseline: str = "black", + perturb_func_kwargs: Optional[Dict[str, Any]] = None, + default_plot_func: Optional[Callable] = None, + **kwargs, + ): + """ + Parameters + ---------- + features_in_step: integer + The size of the step, default=1. + abs: boolean + Indicates whether absolute operation is applied on the attribution, default=False. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + perturb_func: callable + Input perturbation function. If None, the default value is used, + default=baseline_replacement_by_indices. + perturb_baseline: string + Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", + default="black". + perturb_func_kwargs: dict + Keyword arguments to be passed to perturb_func, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + return_auc_per_sample: boolean + Indicates if an AUC score should be computed over the curve and returned. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. + """ + if metric_init.normalise_func is None: + normalise_func = normalise_by_max + + if metric_init.perturb_func is None: + perturb_func = baseline_replacement_by_indices + perturb_func = perturb_func + + if metric_init.perturb_func_kwargs is None: + perturb_func_kwargs = {} + perturb_func_kwargs["perturb_baseline"] = perturb_baseline + + if metric_init.default_plot_func is None: + # TODO. Create plot. + default_plot_func = plotting.plot_pixel_flipping_experiment + + abs = metric_init.abs + normalise = metric_init.normalise + return_aggregate = metric_init.return_aggregate + aggregate_func = metric_init.aggregate_func + display_progressbar = metric_init.display_progressbar + disable_warnings = metric_init.disable_warnings + + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + perturb_func=perturb_func, + perturb_func_kwargs=perturb_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + # Asserts and warnings. + assert ( + not self.metric_init.return_aggregate + ), "Make sure to set return_aggregate=False when calling the inverse estimation." + # TODO. Update warnings. 
+
+    def __call__(
+        self,
+        model,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        a_batch: Optional[np.ndarray] = None,
+        s_batch: Optional[np.ndarray] = None,
+        channel_first: Optional[bool] = None,
+        explain_func: Optional[Callable] = None,
+        explain_func_kwargs: Optional[Dict] = None,
+        model_predict_kwargs: Optional[Dict] = None,
+        softmax: Optional[bool] = True,
+        device: Optional[str] = None,
+        batch_size: int = 64,
+        custom_batch: Optional[Any] = None,
+        **kwargs,
+    ) -> List[float]:
+        """
+        This implementation represents the main logic of the metric and makes the class object callable.
+        It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch),
+        output labels (y_batch) and a torch or tensorflow model (model).
+
+        Calls general_preprocess() with all relevant arguments, calls
+        evaluate_instance() on each instance, and saves results to evaluation_scores.
+        Calls custom_postprocess() afterwards. Finally returns evaluation_scores.
+
+        Parameters
+        ----------
+        model: torch.nn.Module, tf.keras.Model
+            A torch or tensorflow model that is subject to explanation.
+        x_batch: np.ndarray
+            A np.ndarray which contains the input data that are explained.
+        y_batch: np.ndarray
+            A np.ndarray which contains the output labels that are explained.
+        a_batch: np.ndarray, optional
+            A np.ndarray which contains pre-computed attributions i.e., explanations.
+        s_batch: np.ndarray, optional
+            A np.ndarray which contains segmentation masks that match the input.
+        channel_first: boolean, optional
+            Indicates if the image dimensions are channel first, or channel last.
+            Inferred from the input shape if None.
+        explain_func: callable
+            Callable generating attributions.
+        explain_func_kwargs: dict, optional
+            Keyword arguments to be passed to explain_func on call.
+        model_predict_kwargs: dict, optional
+            Keyword arguments to be passed to the model's predict method.
+        softmax: boolean
+            Indicates whether to use softmax probabilities or logits in model prediction.
+            This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used.
+        device: string
+            Indicates the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu".
+        kwargs: optional
+            Keyword arguments.
+
+        Returns
+        -------
+        evaluation_scores: list
+            A list of Any with the evaluation scores of the concerned batch.
+
+        Examples:
+        --------
+        # Minimal imports.
+        >> import quantus
+        >> from quantus import LeNet
+        >> import torch
+        >> import torchvision
+        >> from captum.attr import Saliency
+
+        # Enable GPU.
+        >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+        # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
+        >> model = LeNet()
+        >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
+
+        # Load MNIST datasets and make loaders.
+        >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
+        >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
+
+        # Load a batch of inputs and outputs to use for XAI evaluation.
+        >> x_batch, y_batch = next(iter(test_loader))
+        >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
+
+        # Generate Saliency attributions for the batch of the test set.
+        >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
+        >> a_batch_saliency = a_batch_saliency.cpu().numpy()
+
+        # Initialise the metric and evaluate explanations by calling the metric instance.
+        >> metric = InverseEstimation(metric_init=quantus.PixelFlipping(return_aggregate=False))
+        >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
+        """
+        self.results = {}
+
+        # Run a normal evaluation round.
+        self.metric_init(
+            model=model,
+            x_batch=x_batch,
+            y_batch=y_batch,
+            a_batch=a_batch,
+            s_batch=s_batch,
+            custom_batch=None,
+            channel_first=channel_first,
+            explain_func=explain_func,
+            explain_func_kwargs=explain_func_kwargs,
+            softmax=softmax,
+            device=device,
+            model_predict_kwargs=model_predict_kwargs,
+            **kwargs,
+        )
+        assert len(self.metric_init.evaluation_scores) == len(x_batch), (
+            "To run the inverse estimation, the number of evaluation scores"
+            " must match the number of instances in the batch."
+        )
+        self.all_evaluation_scores_meta.extend(self.metric_init.evaluation_scores)
+
+        # Empty the evaluation scores before re-scoring with the metric.
+        self.metric_init.evaluation_scores = []
+
+        # Run the inverse faithfulness experiment: negating the attributions
+        # inverts the feature ranking, which keeps the approach general.
+        assert a_batch is not None, (
+            "To run the inverse estimation, pre-computed attributions (a_batch)"
+            " must be passed, since their negation defines the inverse explanation."
+        )
+        a_batch_inv = -np.array(a_batch)
+        self.metric_init(
+            model=model,
+            x_batch=x_batch,
+            y_batch=y_batch,
+            a_batch=a_batch_inv,
+            s_batch=s_batch,
+            custom_batch=None,
+            channel_first=channel_first,
+            explain_func=explain_func,
+            explain_func_kwargs=explain_func_kwargs,
+            softmax=softmax,
+            device=device,
+            model_predict_kwargs=model_predict_kwargs,
+            **kwargs,
+        )
+        self.all_evaluation_scores_meta_inverse.extend(
+            self.metric_init.evaluation_scores
+        )
+
+        # Compute the inverse scores as the difference between the original and
+        # the inverse evaluation, empty the metric's evaluation scores again and
+        # overwrite them with the inverse scores.
+        inv_scores = np.array(self.all_evaluation_scores_meta) - np.array(
+            self.all_evaluation_scores_meta_inverse
+        )
+        self.metric_init.evaluation_scores = []
+        self.evaluation_scores.extend(inv_scores)
+
+        # TODO. If all_evaluation_scores is empty, overwrite with inverse scores for those last samples (keep iterator).
+        # Or skip and throw a warning.
+        self.metric_init.all_evaluation_scores = []
+        self.all_evaluation_scores.extend(inv_scores)
+
+        return inv_scores.tolist()
+
+    def custom_postprocess(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: Optional[np.ndarray],
+        a_batch: Optional[np.ndarray],
+        s_batch: np.ndarray,
+        **kwargs,
+    ) -> None:
+        """
+        Post-process the evaluation results.
+
+        Parameters
+        ----------
+        model: torch.nn.Module, tf.keras.Model
+            A torch or tensorflow model e.g., torchvision.models that is subject to explanation.
+        x_batch: np.ndarray
+            A np.ndarray which contains the input data that are explained.
+        y_batch: np.ndarray
+            A np.ndarray which contains the output labels that are explained.
+        a_batch: np.ndarray, optional
+            A np.ndarray which contains pre-computed attributions i.e., explanations.
+        s_batch: np.ndarray, optional
+            A np.ndarray which contains segmentation masks that match the input.
+
+        Returns
+        -------
+        None
+        """
+        # TODO. Implement aggregation method.
+        pass
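Putting the pieces together, an end-to-end round with the new metric could look as follows. The MNIST setup mirrors the docstring example above; the asset path and the use of quantus.explain for the attributions are assumptions, not fixed by this patch:

```python
import numpy as np
import torch
import torchvision
import quantus
from quantus import LeNet
from quantus.metrics.inverse_estimation import InverseEstimation

# Load a pre-trained LeNet and a batch of MNIST data (paths as in the docstring).
model = LeNet()
model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))

test_set = torchvision.datasets.MNIST(
    root="./sample_data",
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
x_batch, y_batch = next(iter(test_loader))
x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Pre-computed attributions are required, since __call__ negates a_batch
# to build the inverse explanations.
a_batch = quantus.explain(model, x_batch, y_batch, method="Saliency")

# Wrap a per-instance faithfulness metric and compute the inverse scores.
metric = InverseEstimation(metric_init=quantus.PixelFlipping(return_aggregate=False))
inv_scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch)
print(np.mean(inv_scores))
```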