From 64cf19c3d18cd76dbc279bee57b499475a00efd8 Mon Sep 17 00:00:00 2001 From: davor Date: Sun, 14 Jul 2024 22:14:42 +0200 Subject: [PATCH 01/11] added batch version for faithfulness correlation --- quantus/functions/perturb_func.py | 52 +++ quantus/functions/similarity_func.py | 14 +- quantus/helpers/utils.py | 101 ++++- quantus/metrics/faithfulness/__init__.py | 1 + .../faithfulness/faithfulness_correlation.py | 358 +++++++++++++++++- 5 files changed, 513 insertions(+), 13 deletions(-) diff --git a/quantus/functions/perturb_func.py b/quantus/functions/perturb_func.py index 4ab352991..7db8daa36 100644 --- a/quantus/functions/perturb_func.py +++ b/quantus/functions/perturb_func.py @@ -118,6 +118,58 @@ def baseline_replacement_by_indices( return arr_perturbed +def batch_baseline_replacement_by_indices( + arr: np.array, + indices: Tuple[slice, ...], + perturb_baseline: Union[float, int, str, np.array], + **kwargs, +) -> np.array: + """ + Replace indices in an array by a given baseline. + + Parameters + ---------- + arr: np.ndarray + Array to be perturbed. Shape N x F, where N is batch size and F is number of features. + indices: np.ndarray + Indices of the array to perturb. Shape N x I, where N is batch size and I is the number of indices to perturb. + perturb_baseline: float, int, str, np.ndarray + The baseline values to replace arr at indices with. + kwargs: optional + Keyword arguments. + + Returns + ------- + arr_perturbed: np.ndarray + The array which some of its indices have been perturbed. + """ + + # Assert dimensions + assert ( + len(arr.shape) == 2 + ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features" + assert ( + len(indices.shape) == 2 + ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb" + + batch_size = arr.shape[0] + arr_perturbed = copy.copy(arr) + + # Get the baseline value. + baseline_value = get_baseline_value( + value=perturb_baseline, + arr=arr, + return_shape=tuple(indices.shape), + batched=True, + **kwargs, + ) + + # Perturb the array. + arr_perturbed[np.arange(batch_size)[:, None], indices] = baseline_value + + return arr_perturbed + + def baseline_replacement_by_shift( arr: np.array, indices: Tuple[slice, ...], # Alt. Union[int, Sequence[int], Tuple[np.array]], diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 88d19a9a7..4e9af76dc 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -35,7 +35,7 @@ def correlation_spearman(a: np.array, b: np.array, **kwargs) -> float: return scipy.stats.spearmanr(a, b)[0] -def correlation_pearson(a: np.array, b: np.array, **kwargs) -> float: +def correlation_pearson(a: np.array, b: np.array, batched=False, **kwargs) -> float: """ Calculate Pearson correlation of two images (or explanations). @@ -45,6 +45,8 @@ def correlation_pearson(a: np.array, b: np.array, **kwargs) -> float: The first array to use for similarity scoring. b: np.ndarray The second array to use for similarity scoring. + batched: bool + True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size and F is the number of features kwargs: optional Keyword arguments. @@ -53,7 +55,13 @@ def correlation_pearson(a: np.array, b: np.array, **kwargs) -> float: float The similarity score. 
""" - return scipy.stats.pearsonr(a, b)[0] + if batched: + assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" + return ( + scipy.stats.pearsonr(a, b, axis=1)[0] + if batched + else scipy.stats.pearsonr(a, b)[0] + ) def correlation_kendall_tau(a: np.array, b: np.array, **kwargs) -> float: @@ -145,7 +153,7 @@ def lipschitz_constant( b: np.array, c: Union[np.array, None], d: Union[np.array, None], - **kwargs + **kwargs, ) -> float: """ Calculate non-negative local Lipschitz abs(||a-b||/||c-d||), where a,b can be f(x) or a(x) and c,d is x. diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 506459792..6853702b4 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -60,11 +60,77 @@ def get_superpixel_segments(img: np.ndarray, segmentation_method: str) -> np.nda ) +def batch_get_baseline_value( + value: Union[str, np.array], + arr: np.ndarray, + return_shape: Tuple, + patch: Optional[np.ndarray] = None, + **kwargs, +) -> Union[np.array, float]: + """ + Get the baseline value to fill the array with, in the shape of return_shape. + + Parameters + ---------- + value: float, int, str, np.ndarray + Either the value (float, int) to fill the array with, a method (string) used to construct + baseline array ("mean", "uniform", "black", "white", "neighbourhood_mean" or + "neighbourhood_random_min_max"), or the array (np.array) to be returned. + arr: np.ndarray + BxCxHxW image array used to calculate baseline values, i.e. for "mean", "black" and "white" methods. + return_shape: tuple + BxCxHxW shape to be returned. + patch: np.ndarray, optional + BxCxHxW patch array to calculate baseline values. + Necessary for "neighbourhood_mean" and "neighbourhood_random_min_max" methods. + kwargs: optional + Keyword arguments. + + Returns + ------- + np.ndarray + Baseline array in return_shape. + + """ + + kwargs["return_shape"] = return_shape + if isinstance(value, (float, int)): + return np.full(return_shape, value) + elif isinstance(value, np.ndarray): + if value.ndim == 0: + return np.full(return_shape, value) + elif value.shape == return_shape: + return value + else: + raise ValueError( + "Shape {} of argument 'value' cannot be fitted to required shape {} of return value".format( + value.shape, return_shape + ) + ) + elif isinstance(value, str): + fill_dict = get_baseline_dict(arr, patch, **kwargs) + if value.lower() == "random": + raise ValueError( + "'random' as a choice for 'perturb_baseline' is deprecated and has been removed from " + "the current release. Please use 'uniform' instead and pass lower- and upper bounds to " + "kwargs as see fit (default values are set to 'uniform_low=0.0' and 'uniform_high=1.0' " + "which will replicate the results of 'random').\n" + ) + if value.lower() not in fill_dict: + raise ValueError( + f"Ensure that 'value'(string) is in {list(fill_dict.keys())}" + ) + return np.full(return_shape, fill_dict[value.lower()]) + else: + raise ValueError("Specify 'value' as a np.array, string, integer or float.") + + def get_baseline_value( value: Union[float, int, str, np.array], arr: np.ndarray, return_shape: Tuple, patch: Optional[np.ndarray] = None, + batched: bool = False, **kwargs, ) -> np.array: """ @@ -83,6 +149,8 @@ def get_baseline_value( patch: np.ndarray, optional CxWxH patch array to calculate baseline values. Necessary for "neighbourhood_mean" and "neighbourhood_random_min_max" methods. + batched: bool, + If True, the arr and patch are assumed to be of shape B x . 
The aggregations are done over the individual batch elements. kwargs: optional Keyword arguments. @@ -108,7 +176,7 @@ def get_baseline_value( ) ) elif isinstance(value, str): - fill_dict = get_baseline_dict(arr, patch, **kwargs) + fill_dict = get_baseline_dict(arr, patch, batched, **kwargs) if value.lower() == "random": raise ValueError( "'random' as a choice for 'perturb_baseline' is deprecated and has been removed from " @@ -120,13 +188,18 @@ def get_baseline_value( raise ValueError( f"Ensure that 'value'(string) is in {list(fill_dict.keys())}" ) - return np.full(return_shape, fill_dict[value.lower()]) + fill_value = fill_dict[value.lower()] + # Expand the second dimension if batched to enable broadcasting + if batched: + fill_value = fill_value[:, None] + return np.full(return_shape, fill_value) + else: raise ValueError("Specify 'value' as a np.array, string, integer or float.") def get_baseline_dict( - arr: np.ndarray, patch: Optional[np.ndarray] = None, **kwargs + arr: np.ndarray, patch: Optional[np.ndarray] = None, batched: bool = False, **kwargs ) -> dict: """ Make a dictionary of baseline approaches depending on the input x (or patch of input). @@ -138,6 +211,8 @@ def get_baseline_dict( patch: np.ndarray, optional CxWxH patch array to calculate baseline values, necessary for "neighbourhood_mean" and "neighbourhood_random_min_max" methods. + batched: bool + If True, the arr and patch are assumed to be of shape B x . The aggregations are done over the individual batch elements. kwargs: optional Keyword arguments.. @@ -146,20 +221,28 @@ def get_baseline_dict( fill_dict: dict Maps all available baseline methods to baseline values. """ + + # Aggregate over elements of batch + aggregation_axes = ( + tuple(range(1 if batched else 0, len(arr.shape))) + if not patch + else tuple(range(1 if batched else 0, len(patch.shape))) + ) fill_dict = { - "mean": float(arr.mean()), + "mean": arr.mean(axis=aggregation_axes), "uniform": np.random.uniform( low=kwargs.get("uniform_low", 0.0), high=kwargs.get("uniform_high", 1.0), size=kwargs["return_shape"], ), - "black": float(arr.min()), - "white": float(arr.max()), + "black": arr.min(axis=aggregation_axes), + "white": arr.max(axis=aggregation_axes), } if patch is not None: - fill_dict["neighbourhood_mean"] = (float(patch.mean()),) - fill_dict["neighbourhood_random_min_max"] = float( - np.random.uniform(low=patch.min(), high=patch.max()) + fill_dict["neighbourhood_mean"] = patch.mean(axis=aggregation_axes) + fill_dict["neighbourhood_random_min_max"] = np.random.uniform( + low=patch.min(axis=aggregation_axes), + high=patch.max(axis=aggregation_axes), ) return fill_dict diff --git a/quantus/metrics/faithfulness/__init__.py b/quantus/metrics/faithfulness/__init__.py index dcf83d6e2..fd78ae5ec 100644 --- a/quantus/metrics/faithfulness/__init__.py +++ b/quantus/metrics/faithfulness/__init__.py @@ -5,6 +5,7 @@ # Quantus project URL: . 
from quantus.metrics.faithfulness.faithfulness_correlation import ( + BatchFaithfulnessCorrelation, FaithfulnessCorrelation, ) from quantus.metrics.faithfulness.faithfulness_estimate import FaithfulnessEstimate diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 6e8e405d5..295b6693e 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -10,7 +10,10 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + baseline_replacement_by_indices, + batch_baseline_replacement_by_indices, +) from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -29,6 +32,359 @@ from typing_extensions import final +@final +class BatchFaithfulnessCorrelation(Metric[List[float]]): + """ + Implementation of faithfulness correlation by Bhatt et al., 2020. + + The Faithfulness Correlation metric intend to capture an explanation's relative faithfulness + (or 'fidelity') with respect to the model behaviour. + + Faithfulness correlation scores shows to what extent the predicted logits of each modified test point and + the average explanation attribution for only the subset of features are (linearly) correlated, taking the + average over multiple runs and test samples. The metric returns one float per input-attribution pair that + ranges between -1 and 1, where higher scores are better. + + For each test sample, |S| features are randomly selected and replace them with baseline values (zero baseline + or average of set). Thereafter, Pearson’s correlation coefficient between the predicted logits of each modified + test point and the average explanation attribution for only the subset of features is calculated. Results is + average over multiple runs and several test samples. + + References: + 1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model + explanations." IJCAI (2020): 3016-3022. + + Attributes: + - _name: The name of the metric. + - _data_applicability: The data types that the metric implementation currently supports. + - _models: The model types that this metric can work with. + - score_direction: How to interpret the scores, whether higher/ lower values are considered better. + - evaluation_category: What property/ explanation quality that this metric measures. 
+ """ + + name = "Faithfulness Correlation" + data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} + model_applicability = {ModelType.TORCH, ModelType.TF} + score_direction = ScoreDirection.HIGHER + evaluation_category = EvaluationCategory.FAITHFULNESS + + def __init__( + self, + similarity_func: Optional[Callable] = None, + nr_runs: int = 100, + subset_size: int = 224, + abs: bool = False, + normalise: bool = True, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + perturb_func: Optional[Callable] = None, + perturb_baseline: str = "black", + perturb_func_kwargs: Optional[Dict[str, Any]] = None, + return_aggregate: bool = True, + aggregate_func: Optional[Callable] = None, + default_plot_func: Optional[Callable] = None, + disable_warnings: bool = False, + display_progressbar: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + similarity_func: callable + Similarity function applied to compare input and perturbed input. + If None, the default value is used, default=correlation_pearson. + nr_runs: integer + The number of runs (for each input and explanation pair), default=100. + subset_size: integer + The size of subset, default=224. + abs: boolean + Indicates whether absolute operation is applied on the attribution, default=False. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + perturb_func: callable + Input perturbation function. If None, the default value is used, + default=baseline_replacement_by_indices. + perturb_baseline: string + Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", + default="black". + perturb_func_kwargs: dict + Keyword arguments to be passed to perturb_func, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. + """ + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + # Save metric-specific attributes. + if perturb_func is None: + perturb_func = batch_baseline_replacement_by_indices + + if similarity_func is None: + similarity_func = correlation_pearson + + self.similarity_func = similarity_func + self.nr_runs = nr_runs + self.subset_size = subset_size + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) + + # Asserts and warnings. 
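+        # (Note: warn_parameterisation below lists the hyperparameters this metric is
+        # most sensitive to; pass disable_warnings=True to silence the message.)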
+ if not self.disable_warnings: + warn.warn_parameterisation( + metric_name=self.__class__.__name__, + sensitive_params=( + "baseline value 'perturb_baseline', size of subset |S| 'subset_size'" + " and the number of runs (for each input and explanation pair) " + "'nr_runs'" + ), + citation=( + "Bhatt, Umang, Adrian Weller, and José MF Moura. 'Evaluating and aggregating " + "feature-based model explanations.' arXiv preprint arXiv:2005.00631 (2020)" + ), + ) + + def __call__( + self, + model, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: Optional[np.ndarray] = None, + s_batch: Optional[np.ndarray] = None, + channel_first: Optional[bool] = None, + explain_func: Optional[Callable] = None, + explain_func_kwargs: Optional[Dict] = None, + model_predict_kwargs: Optional[Dict] = None, + softmax: Optional[bool] = False, + device: Optional[str] = None, + batch_size: int = 64, + custom_batch: Optional[Any] = None, + **kwargs, + ) -> List[float]: + """ + This implementation represents the main logic of the metric and makes the class object callable. + It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + output labels (y_batch) and a torch or tensorflow model (model). + + Calls general_preprocess() with all relevant arguments, calls + () on each instance, and saves results to evaluation_scores. + Calls custom_postprocess() afterwards. Finally returns evaluation_scores. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + channel_first: boolean, optional + Indicates of the image dimensions are channel first, or channel last. + Inferred from the input shape if None. + explain_func: callable + Callable generating attributions. + explain_func_kwargs: dict, optional + Keyword arguments to be passed to explain_func on call. + model_predict_kwargs: dict, optional + Keyword arguments to be passed to the model's predict method. + softmax: boolean + Indicates whether to use softmax probabilities or logits in model prediction. + This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. + device: string + Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". + custom_batch: any + Any object that can be passed to the evaluation process. + Gives flexibility to the user to adapt for implementing their own metric. + kwargs: optional + Keyword arguments. + + Returns + ------- + evaluation_scores: list + a list of Any with the evaluation scores of the concerned batch. + + Examples: + -------- + # Minimal imports. + >> import quantus + >> from quantus import LeNet + >> import torch + + # Enable GPU. + >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). + >> model = LeNet() + >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) + + # Load MNIST datasets and make loaders. 
+ >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) + >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) + + # Load a batch of inputs and outputs to use for XAI evaluation. + >> x_batch, y_batch = iter(test_loader).next() + >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() + + # Generate Saliency attributions of the test set batch of the test set. + >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) + >> a_batch_saliency = a_batch_saliency.cpu().numpy() + + # Initialise the metric and evaluate explanations by calling the metric instance. + >> metric = Metric(abs=True, normalise=False) + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) + """ + return super().__call__( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + s_batch=s_batch, + custom_batch=custom_batch, + channel_first=channel_first, + explain_func=explain_func, + explain_func_kwargs=explain_func_kwargs, + softmax=softmax, + device=device, + model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, + **kwargs, + ) + + def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: + """ + Implementation of custom_preprocess_batch. + + Parameters + ---------- + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + kwargs: + Unused. + + Returns + ------- + tuple + In addition to the x_batch, y_batch, a_batch, s_batch and custom_batch, + returning a custom preprocess batch (custom_preprocess_batch). + """ + # Asserts. + asserts.assert_value_smaller_than_input_size( + x=x_batch, value=self.subset_size, value_name="subset_size" + ) + + def evaluate_batch( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[float]: + """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + kwargs: + Unused. + + Returns + ------- + scores_batch: + The evaluation results. + """ + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Predict on input. + x_input = model.shape_input( + x_batch, x_batch.shape, channel_first=True, batched=True + ) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + pred_deltas = [] + att_sums = [] + + x_batch_shape = x_batch.shape + # For each test data point, execute a couple of runs. + for i_ix in range(self.nr_runs): + # Randomly mask by subset size. + a_ix = np.stack( + [ + np.random.choice(n_features, self.subset_size, replace=False) + for _ in range(batch_size) + ], + axis=0, + ) + x_perturbed = self.perturb_func( + arr=x_batch.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change( + x=x_element, x_perturbed=x_perturbed_element + ) + + # Predict on perturbed input x. 
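+            # (model.predict returns a B x num_classes array; indexing it with
+            # [np.arange(batch_size), y_batch] below selects each sample's score
+            # for its own target class.)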
+ x_input = model.shape_input( + x_perturbed, x_batch.shape, channel_first=True, batched=True + ) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + pred_deltas.append(y_pred - y_pred_perturb) + + # Sum attributions of the random subset. + att_sums.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=-1)) + pred_deltas = np.stack(pred_deltas, axis=1) + att_sums = np.stack(att_sums, axis=1) + + similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) + + return similarity + + @final class FaithfulnessCorrelation(Metric[List[float]]): """ From 339136171fdad6d7156cacf9d6f23e8db4f8389d Mon Sep 17 00:00:00 2001 From: davor Date: Tue, 16 Jul 2024 08:56:29 +0200 Subject: [PATCH 02/11] finished batch implementations of the first 5 faithfulness metrics --- quantus/functions/similarity_func.py | 17 +- quantus/helpers/utils.py | 72 +--- quantus/metrics/faithfulness/__init__.py | 13 +- .../faithfulness/faithfulness_correlation.py | 2 +- .../faithfulness/faithfulness_estimate.py | 344 +++++++++++++++- quantus/metrics/faithfulness/monotonicity.py | 338 +++++++++++++++- .../faithfulness/monotonicity_correlation.py | 373 +++++++++++++++++- .../metrics/faithfulness/pixel_flipping.py | 342 +++++++++++++++- 8 files changed, 1419 insertions(+), 82 deletions(-) diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 4e9af76dc..da3c2f0d0 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -14,7 +14,7 @@ import skimage -def correlation_spearman(a: np.array, b: np.array, **kwargs) -> float: +def correlation_spearman(a: np.array, b: np.array, batched=False, **kwargs) -> float: """ Calculate Spearman rank of two images (or explanations). @@ -24,6 +24,8 @@ def correlation_spearman(a: np.array, b: np.array, **kwargs) -> float: The first array to use for similarity scoring. b: np.ndarray The second array to use for similarity scoring. + batched: bool + True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size and F is the number of features kwargs: optional Keyword arguments. @@ -32,6 +34,12 @@ def correlation_spearman(a: np.array, b: np.array, **kwargs) -> float: float The similarity score. """ + if batched: + assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" + # Spearman correlation is not calculated row-wise like pearson. 
Instead it is calculated between each + # pair from BOTH a and b + correlation = scipy.stats.spearmanr(a, b, axis=1)[0][: len(a), len(a) :] + return np.diag(correlation) return scipy.stats.spearmanr(a, b)[0] @@ -57,11 +65,8 @@ def correlation_pearson(a: np.array, b: np.array, batched=False, **kwargs) -> fl """ if batched: assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" - return ( - scipy.stats.pearsonr(a, b, axis=1)[0] - if batched - else scipy.stats.pearsonr(a, b)[0] - ) + return scipy.stats.pearsonr(a, b, axis=1)[0] + return scipy.stats.pearsonr(a, b)[0] def correlation_kendall_tau(a: np.array, b: np.array, **kwargs) -> float: diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 6853702b4..fa156b2b3 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -60,71 +60,6 @@ def get_superpixel_segments(img: np.ndarray, segmentation_method: str) -> np.nda ) -def batch_get_baseline_value( - value: Union[str, np.array], - arr: np.ndarray, - return_shape: Tuple, - patch: Optional[np.ndarray] = None, - **kwargs, -) -> Union[np.array, float]: - """ - Get the baseline value to fill the array with, in the shape of return_shape. - - Parameters - ---------- - value: float, int, str, np.ndarray - Either the value (float, int) to fill the array with, a method (string) used to construct - baseline array ("mean", "uniform", "black", "white", "neighbourhood_mean" or - "neighbourhood_random_min_max"), or the array (np.array) to be returned. - arr: np.ndarray - BxCxHxW image array used to calculate baseline values, i.e. for "mean", "black" and "white" methods. - return_shape: tuple - BxCxHxW shape to be returned. - patch: np.ndarray, optional - BxCxHxW patch array to calculate baseline values. - Necessary for "neighbourhood_mean" and "neighbourhood_random_min_max" methods. - kwargs: optional - Keyword arguments. - - Returns - ------- - np.ndarray - Baseline array in return_shape. - - """ - - kwargs["return_shape"] = return_shape - if isinstance(value, (float, int)): - return np.full(return_shape, value) - elif isinstance(value, np.ndarray): - if value.ndim == 0: - return np.full(return_shape, value) - elif value.shape == return_shape: - return value - else: - raise ValueError( - "Shape {} of argument 'value' cannot be fitted to required shape {} of return value".format( - value.shape, return_shape - ) - ) - elif isinstance(value, str): - fill_dict = get_baseline_dict(arr, patch, **kwargs) - if value.lower() == "random": - raise ValueError( - "'random' as a choice for 'perturb_baseline' is deprecated and has been removed from " - "the current release. 
Please use 'uniform' instead and pass lower- and upper bounds to " - "kwargs as see fit (default values are set to 'uniform_low=0.0' and 'uniform_high=1.0' " - "which will replicate the results of 'random').\n" - ) - if value.lower() not in fill_dict: - raise ValueError( - f"Ensure that 'value'(string) is in {list(fill_dict.keys())}" - ) - return np.full(return_shape, fill_dict[value.lower()]) - else: - raise ValueError("Specify 'value' as a np.array, string, integer or float.") - - def get_baseline_value( value: Union[float, int, str, np.array], arr: np.ndarray, @@ -233,7 +168,8 @@ def get_baseline_dict( "uniform": np.random.uniform( low=kwargs.get("uniform_low", 0.0), high=kwargs.get("uniform_high", 1.0), - size=kwargs["return_shape"], + # Return only (batch_size, ) if batched to align with other fill_dict values + size=kwargs["return_shape"][0] if batched else kwargs["return_shape"], ), "black": arr.min(axis=aggregation_axes), "white": arr.max(axis=aggregation_axes), @@ -1082,7 +1018,7 @@ def offset_coordinates( return off_coords[valid], valid -def calculate_auc(values: np.array, dx: int = 1): +def calculate_auc(values: np.array, dx: int = 1, batched: bool = False): """ Calculate area under the curve using the composite trapezoidal rule. @@ -1098,6 +1034,8 @@ def calculate_auc(values: np.array, dx: int = 1): np.ndarray Definite integral of values. """ + if batched: + return np.trapz(values, dx=dx, axis=1) return np.trapz(np.array(values), dx=dx) diff --git a/quantus/metrics/faithfulness/__init__.py b/quantus/metrics/faithfulness/__init__.py index fd78ae5ec..505ac3bd8 100644 --- a/quantus/metrics/faithfulness/__init__.py +++ b/quantus/metrics/faithfulness/__init__.py @@ -8,14 +8,21 @@ BatchFaithfulnessCorrelation, FaithfulnessCorrelation, ) -from quantus.metrics.faithfulness.faithfulness_estimate import FaithfulnessEstimate +from quantus.metrics.faithfulness.faithfulness_estimate import ( + BatchFaithfulnessEstimate, + FaithfulnessEstimate, +) from quantus.metrics.faithfulness.infidelity import Infidelity from quantus.metrics.faithfulness.irof import IROF -from quantus.metrics.faithfulness.monotonicity import Monotonicity +from quantus.metrics.faithfulness.monotonicity import BatchMonotonicity, Monotonicity from quantus.metrics.faithfulness.monotonicity_correlation import ( + BatchMonotonicityCorrelation, MonotonicityCorrelation, ) -from quantus.metrics.faithfulness.pixel_flipping import PixelFlipping +from quantus.metrics.faithfulness.pixel_flipping import ( + BatchPixelFlipping, + PixelFlipping, +) from quantus.metrics.faithfulness.region_perturbation import RegionPerturbation from quantus.metrics.faithfulness.road import ROAD from quantus.metrics.faithfulness.selectivity import Selectivity diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 295b6693e..3b5238f76 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -382,7 +382,7 @@ def evaluate_batch( similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) - return similarity + return similarity.tolist() @final diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index dad5fdaaa..3d9adf88d 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -9,8 +9,12 @@ from typing import Any, Callable, Dict, List, Optional 
import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + baseline_replacement_by_indices, + batch_baseline_replacement_by_indices, +) from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -29,6 +33,344 @@ from typing_extensions import final +@final +class BatchFaithfulnessEstimate(Metric[List[float]]): + """ + Implementation of Faithfulness Estimate by Alvares-Melis at el., 2018a and 2018b. + + Computes the correlations of probability drops and the relevance scores on various points, + showing the aggregate statistics. + + References: + 1) David Alvarez-Melis and Tommi S. Jaakkola.: "Towards robust interpretability with self-explaining + neural networks." NeurIPS (2018): 7786-7795. + + Attributes: + - _name: The name of the metric. + - _data_applicability: The data types that the metric implementation currently supports. + - _models: The model types that this metric can work with. + - score_direction: How to interpret the scores, whether higher/ lower values are considered better. + - evaluation_category: What property/ explanation quality that this metric measures. + """ + + name = "Faithfulness Estimate" + data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} + model_applicability = {ModelType.TORCH, ModelType.TF} + score_direction = ScoreDirection.HIGHER + evaluation_category = EvaluationCategory.FAITHFULNESS + + def __init__( + self, + similarity_func: Optional[Callable] = None, + features_in_step: int = 1, + abs: bool = False, + normalise: bool = True, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + perturb_func: Optional[Callable] = None, + perturb_baseline: str = "black", + perturb_func_kwargs: Optional[Dict[str, Any]] = None, + return_aggregate: bool = False, + aggregate_func: Optional[Callable] = None, + default_plot_func: Optional[Callable] = None, + disable_warnings: bool = False, + display_progressbar: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + similarity_func: callable + Similarity function applied to compare input and perturbed input. + If None, the default value is used, default=correlation_spearman. + features_in_step: integer + The size of the step, default=1. + abs: boolean + Indicates whether absolute operation is applied on the attribution, default=False. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + perturb_func: callable + Input perturbation function. If None, the default value is used, + default=baseline_replacement_by_indices. + perturb_baseline: string + Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", + default="black". + perturb_func_kwargs: dict + Keyword arguments to be passed to perturb_func, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + default_plot_func: callable + Callable that plots the metrics result. 
+ disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. + """ + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + # Save metric-specific attributes. + if similarity_func is None: + similarity_func = correlation_pearson + if perturb_func is None: + perturb_func = batch_baseline_replacement_by_indices + self.similarity_func = similarity_func + self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) + + # Asserts and warnings. + if not self.disable_warnings: + warn.warn_parameterisation( + metric_name=self.__class__.__name__, + sensitive_params=( + "baseline value 'perturb_baseline' and similarity function " + "'similarity_func'" + ), + citation=( + "Alvarez-Melis, David, and Tommi S. Jaakkola. 'Towards robust interpretability" + " with self-explaining neural networks.' arXiv preprint arXiv:1806.07538 (2018)" + ), + ) + + def __call__( + self, + model, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: Optional[np.ndarray] = None, + s_batch: Optional[np.ndarray] = None, + channel_first: Optional[bool] = None, + explain_func: Optional[Callable] = None, + explain_func_kwargs: Optional[Dict] = None, + model_predict_kwargs: Optional[Dict] = None, + softmax: Optional[bool] = False, + device: Optional[str] = None, + batch_size: int = 64, + custom_batch: Optional[Any] = None, + **kwargs, + ) -> List[float]: + """ + This implementation represents the main logic of the metric and makes the class object callable. + It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + output labels (y_batch) and a torch or tensorflow model (model). + + Calls general_preprocess() with all relevant arguments, calls + () on each instance, and saves results to evaluation_scores. + Calls custom_postprocess() afterwards. Finally returns evaluation_scores. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + channel_first: boolean, optional + Indicates of the image dimensions are channel first, or channel last. + Inferred from the input shape if None. + explain_func: callable + Callable generating attributions. + explain_func_kwargs: dict, optional + Keyword arguments to be passed to explain_func on call. + model_predict_kwargs: dict, optional + Keyword arguments to be passed to the model's predict method. + softmax: boolean + Indicates whether to use softmax probabilities or logits in model prediction. + This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. 
+ device: string + Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". + custom_batch: any + Any object that can be passed to the evaluation process. + Gives flexibility to the user to adapt for implementing their own metric. + kwargs: optional + Keyword arguments. + + Returns + ------- + evaluation_scores: list + a list of Any with the evaluation scores of the concerned batch. + + Examples: + -------- + # Minimal imports. + >> import quantus + >> from quantus import LeNet + >> import torch + + # Enable GPU. + >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). + >> model = LeNet() + >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) + + # Load MNIST datasets and make loaders. + >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) + >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) + + # Load a batch of inputs and outputs to use for XAI evaluation. + >> x_batch, y_batch = iter(test_loader).next() + >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() + + # Generate Saliency attributions of the test set batch of the test set. + >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) + >> a_batch_saliency = a_batch_saliency.cpu().numpy() + + # Initialise the metric and evaluate explanations by calling the metric instance. + >> metric = Metric(abs=True, normalise=False) + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) + """ + return super().__call__( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + s_batch=s_batch, + custom_batch=custom_batch, + channel_first=channel_first, + explain_func=explain_func, + explain_func_kwargs=explain_func_kwargs, + model_predict_kwargs=model_predict_kwargs, + softmax=softmax, + device=device, + batch_size=batch_size, + **kwargs, + ) + + def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: + """ + Implementation of custom_preprocess_batch. + + Parameters + ---------- + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + kwargs: + Unused. + + Returns + ------- + tuple + In addition to the x_batch, y_batch, a_batch, s_batch and custom_batch, + returning a custom preprocess batch (custom_preprocess_batch). + """ + # Asserts. + asserts.assert_features_in_step( + features_in_step=self.features_in_step, + input_shape=x_batch.shape[2:], + ) + + def evaluate_batch( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[float]: + """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + kwargs: + Unused. + + Returns + ------- + scores_batch: + The evaluation results. + """ + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). 
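+        # (np.argsort sorts in ascending order, so negating the attributions yields
+        # a descending ranking, e.g. np.argsort(-np.array([0.1, 0.9, 0.5])) -> [1, 2, 0].)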
+ a_indices = np.argsort(-a_batch, axis=1) + + # Predict on input. + x_input = model.shape_input( + x_batch, x_batch.shape, channel_first=True, batched=True + ) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + n_perturbations = math.ceil(n_features / self.features_in_step) + pred_deltas = [] + att_sums = [] + x_batch_shape = x_batch.shape + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. + a_ix = a_indices[ + :, + perturbation_step_index + * self.features_in_step : (perturbation_step_index + 1) + * self.features_in_step, + ] + x_perturbed = self.perturb_func( + arr=x_batch.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change( + x=x_element, x_perturbed=x_perturbed_element + ) + + # Predict on perturbed input x. + x_input = model.shape_input( + x_perturbed, x_batch.shape, channel_first=True, batched=True + ) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + pred_deltas.append(y_pred - y_pred_perturb) + + # Sum attributions of the random subset. + att_sums.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=-1)) + pred_deltas = np.stack(pred_deltas, axis=1) + att_sums = np.stack(att_sums, axis=1) + + similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) + + return similarity.tolist() + + @final class FaithfulnessEstimate(Metric[List[float]]): """ diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 0f3ca879a..d039d8799 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -9,8 +9,12 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + baseline_replacement_by_indices, + batch_baseline_replacement_by_indices, +) from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( DataType, @@ -28,6 +32,338 @@ from typing_extensions import final +@final +class BatchMonotonicity(Metric[List[float]]): + """ + Implementation of Monotonicity metric by Arya at el., 2019. + + Monotonicity tests if adding more positive evidence increases the probability + of classification in the specified class. + + It captures attributions' faithfulness by incrementally adding each attribute + in order of increasing importance and evaluating the effect on model performance. + As more features are added, the performance of the model is expected to increase + and thus result in monotonically increasing model performance. + + References: + 1) Vijay Arya et al.: "One explanation does not fit all: A toolkit and taxonomy of ai explainability + techniques." arXiv preprint arXiv:1909.03012 (2019). + 2) Ronny Luss et al.: "Generating contrastive explanations with monotonic attribute functions." + arXiv preprint arXiv:1905.12698 (2019). + + Attributes: + - _name: The name of the metric. + - _data_applicability: The data types that the metric implementation currently supports. + - _models: The model types that this metric can work with. + - score_direction: How to interpret the scores, whether higher/ lower values are considered better. 
+ - evaluation_category: What property/ explanation quality that this metric measures. + """ + + name = "Monotonicity" + data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} + model_applicability = {ModelType.TORCH, ModelType.TF} + score_direction = ScoreDirection.HIGHER + evaluation_category = EvaluationCategory.FAITHFULNESS + + def __init__( + self, + features_in_step: int = 1, + abs: bool = True, + normalise: bool = True, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + perturb_func: Optional[Callable] = None, + perturb_baseline: str = "black", + perturb_func_kwargs: Optional[Dict[str, Any]] = None, + return_aggregate: bool = False, + aggregate_func: Optional[Callable] = None, + default_plot_func: Optional[Callable] = None, + disable_warnings: bool = False, + display_progressbar: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + features_in_step: integer + The size of the step, default=1. + abs: boolean + Indicates whether absolute operation is applied on the attribution, default=True. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + perturb_func: callable + Input perturbation function. If None, the default value is used, + default=baseline_replacement_by_indices. + perturb_baseline: string + Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", + default="black". + perturb_func_kwargs: dict + Keyword arguments to be passed to perturb_func, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. + """ + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + if perturb_func is None: + perturb_func = batch_baseline_replacement_by_indices + + # Save metric-specific attributes. + self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) + + # Asserts and warnings. + if not self.disable_warnings: + warn.warn_parameterisation( + metric_name=self.__class__.__name__, + sensitive_params=( + "baseline value 'perturb_baseline', also, the monotonicity " + "constraint between your given model and explanation method should be assessed" + ), + citation=( + "Arya, Vijay, et al. 'One explanation does not fit all: A toolkit and taxonomy" + " of ai explainability techniques.' 
arXiv preprint arXiv:1909.03012 (2019)" + ), + ) + + def __call__( + self, + model, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: Optional[np.ndarray] = None, + s_batch: Optional[np.ndarray] = None, + channel_first: Optional[bool] = None, + explain_func: Optional[Callable] = None, + explain_func_kwargs: Optional[Dict] = None, + model_predict_kwargs: Optional[Dict] = None, + softmax: Optional[bool] = True, + device: Optional[str] = None, + batch_size: int = 64, + **kwargs, + ) -> List[float]: + """ + This implementation represents the main logic of the metric and makes the class object callable. + It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + output labels (y_batch) and a torch or tensorflow model (model). + + Calls general_preprocess() with all relevant arguments, calls + () on each instance, and saves results to evaluation_scores. + Calls custom_postprocess() afterwards. Finally returns evaluation_scores. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + channel_first: boolean, optional + Indicates of the image dimensions are channel first, or channel last. + Inferred from the input shape if None. + explain_func: callable + Callable generating attributions. + explain_func_kwargs: dict, optional + Keyword arguments to be passed to explain_func on call. + model_predict_kwargs: dict, optional + Keyword arguments to be passed to the model's predict method. + softmax: boolean + Indicates whether to use softmax probabilities or logits in model prediction. + This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. + device: string + Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". + kwargs: optional + Keyword arguments. + + Returns + ------- + evaluation_scores: list + a list of Any with the evaluation scores of the concerned batch. + + Examples: + -------- + # Minimal imports. + >> import quantus + >> from quantus import LeNet + >> import torch + + # Enable GPU. + >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). + >> model = LeNet() + >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) + + # Load MNIST datasets and make loaders. + >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) + >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) + + # Load a batch of inputs and outputs to use for XAI evaluation. + >> x_batch, y_batch = iter(test_loader).next() + >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() + + # Generate Saliency attributions of the test set batch of the test set. + >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) + >> a_batch_saliency = a_batch_saliency.cpu().numpy() + + # Initialise the metric and evaluate explanations by calling the metric instance. 
+ >> metric = Metric(abs=True, normalise=False) + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) + """ + return super().__call__( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + s_batch=s_batch, + custom_batch=None, + channel_first=channel_first, + explain_func=explain_func, + explain_func_kwargs=explain_func_kwargs, + softmax=softmax, + device=device, + model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, + **kwargs, + ) + + def custom_preprocess( + self, + x_batch: np.ndarray, + **kwargs, + ) -> None: + """ + Implementation of custom_preprocess_batch. + + Parameters + ---------- + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + kwargs: + Unused. + + Returns + ------- + None + """ + # Asserts. + asserts.assert_features_in_step( + features_in_step=self.features_in_step, + input_shape=x_batch.shape[2:], + ) + + def evaluate_batch( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[float]: + """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + + Parameters + ---------- + model: ModelInterface + A ModelInterface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + kwargs: + Unused. + + Returns + ------- + scores_batch: + The evaluation results. + """ + # Prepare shapes. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (ascending). + a_indices = np.argsort(a_batch, axis=1) + + n_perturbations = math.ceil(n_features / self.features_in_step) + preds = [] + + # Copy the input x but fill with baseline values. + baseline_value = utils.get_baseline_value( + value=self.perturb_func.keywords["perturb_baseline"], # type: ignore + arr=x_batch.reshape(batch_size, -1), + return_shape=( + batch_size, + n_features, + ), # TODO. Double-check this over using = (1,). + batched=True, + ) + x_baseline = np.full((batch_size, n_features), baseline_value).reshape( + *x_batch.shape + ) + + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. + a_ix = a_indices[ + :, + perturbation_step_index + * self.features_in_step : (perturbation_step_index + 1) + * self.features_in_step, + ] + x_baseline = self.perturb_func( + arr=x_baseline.reshape(batch_size, -1), + indices=a_ix, + ) + x_baseline = x_baseline.reshape(*x_batch.shape) + + # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). 
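+            # (Per-step predictions are stacked below; monotonicity then requires each
+            # sample's scores to be non-decreasing across perturbation steps.)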
+ x_input = model.shape_input( + x_baseline, x_batch.shape, channel_first=True, batched=True + ) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + preds.append(y_pred_perturb) + preds = np.stack(preds, axis=1) + + return np.all(np.diff(preds) >= 0, axis=1).tolist() + + @final class Monotonicity(Metric[List[float]]): """ diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 7efddf17e..9505968f0 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -10,8 +10,12 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + baseline_replacement_by_indices, + batch_baseline_replacement_by_indices, +) from quantus.functions.similarity_func import correlation_spearman from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -30,6 +34,371 @@ from typing_extensions import final +@final +class BatchMonotonicityCorrelation(Metric[List[float]]): + """ + Implementation of Monotonicity Correlation metric by Nguyen at el., 2020. + + Monotonicity measures the (Spearman’s) correlation coefficient of the absolute values of the attributions + and the uncertainty in probability estimation. The paper argues that if attributions are not monotonic + then they are not providing the correct importance of the feature. + + References: + 1) An-phi Nguyen and María Rodríguez Martínez.: "On quantitative aspects of model + interpretability." arXiv preprint arXiv:2007.07584 (2020). + + Attributes: + - _name: The name of the metric. + - _data_applicability: The data types that the metric implementation currently supports. + - _models: The model types that this metric can work with. + - score_direction: How to interpret the scores, whether higher/ lower values are considered better. + - evaluation_category: What property/ explanation quality that this metric measures. + """ + + name = "Monotonicity" + data_applicability = { + DataType.IMAGE, + DataType.TIMESERIES, + DataType.TABULAR, + } + model_applicability = {ModelType.TORCH, ModelType.TF} + score_direction = ScoreDirection.HIGHER + evaluation_category = EvaluationCategory.FAITHFULNESS + + def __init__( + self, + similarity_func: Optional[Callable] = None, + eps: float = 1e-5, + nr_samples: int = 100, + features_in_step: int = 1, + abs: bool = True, + normalise: bool = True, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + perturb_func: Optional[Callable] = None, + perturb_baseline: str = "uniform", + perturb_func_kwargs: Optional[Dict[str, Any]] = None, + return_aggregate: bool = False, + aggregate_func: Optional[Callable] = None, + default_plot_func: Optional[Callable] = None, + disable_warnings: bool = False, + display_progressbar: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + similarity_func: callable + Similarity function applied to compare input and perturbed input. + If None, the default value is used, default=correlation_spearman. + eps: float + Attributions threshold, default=1e-5. + nr_samples: integer + The number of samples to iterate over, default=100. + features_in_step: integer + The size of the step, default=1. 
+ abs: boolean + Indicates whether absolute operation is applied on the attribution, default=True. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + perturb_func: callable + Input perturbation function. If None, the default value is used, + default=baseline_replacement_by_indices. + perturb_baseline: string + Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", + default="uniform". + perturb_func_kwargs: dict + Keyword arguments to be passed to perturb_func, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. + """ + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + # Save metric-specific attributes. + if similarity_func is None: + similarity_func = correlation_spearman + + if perturb_func is None: + perturb_func = batch_baseline_replacement_by_indices + + self.similarity_func = similarity_func + + self.eps = eps + self.nr_samples = nr_samples + self.features_in_step = features_in_step + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) + + # Asserts and warnings. + if not self.disable_warnings: + warn.warn_parameterisation( + metric_name=self.__class__.__name__, + sensitive_params=( + "baseline value 'perturb_baseline', threshold value 'eps' and number " + "of samples to iterate over 'nr_samples'" + ), + citation=( + "Nguyen, An-phi, and María Rodríguez Martínez. 'On quantitative aspects of " + "model interpretability.' arXiv preprint arXiv:2007.07584 (2020)" + ), + ) + + def __call__( + self, + model, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: Optional[np.ndarray] = None, + s_batch: Optional[np.ndarray] = None, + channel_first: Optional[bool] = None, + explain_func: Optional[Callable] = None, + explain_func_kwargs: Optional[Dict] = None, + model_predict_kwargs: Optional[Dict] = None, + softmax: Optional[bool] = True, + device: Optional[str] = None, + batch_size: int = 64, + custom_batch: Optional[Any] = None, + **kwargs, + ) -> List[float]: + """ + This implementation represents the main logic of the metric and makes the class object callable. + It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + output labels (y_batch) and a torch or tensorflow model (model). + + Calls general_preprocess() with all relevant arguments, calls + () on each instance, and saves results to evaluation_scores. + Calls custom_postprocess() afterwards. Finally returns evaluation_scores. 
+ + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + channel_first: boolean, optional + Indicates of the image dimensions are channel first, or channel last. + Inferred from the input shape if None. + explain_func: callable + Callable generating attributions. + explain_func_kwargs: dict, optional + Keyword arguments to be passed to explain_func on call. + model_predict_kwargs: dict, optional + Keyword arguments to be passed to the model's predict method. + softmax: boolean + Indicates whether to use softmax probabilities or logits in model prediction. + This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. + device: string + Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". + custom_batch: any + Any object that can be passed to the evaluation process. + Gives flexibility to the user to adapt for implementing their own metric. + kwargs: optional + Keyword arguments. + + Returns + ------- + evaluation_scores: list + a list of Any with the evaluation scores of the concerned batch. + + Examples: + -------- + # Minimal imports. + >> import quantus + >> from quantus import LeNet + >> import torch + + # Enable GPU. + >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). + >> model = LeNet() + >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) + + # Load MNIST datasets and make loaders. + >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) + >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) + + # Load a batch of inputs and outputs to use for XAI evaluation. + >> x_batch, y_batch = iter(test_loader).next() + >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() + + # Generate Saliency attributions of the test set batch of the test set. + >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) + >> a_batch_saliency = a_batch_saliency.cpu().numpy() + + # Initialise the metric and evaluate explanations by calling the metric instance. + >> metric = Metric(abs=True, normalise=False) + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) + """ + return super().__call__( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + s_batch=s_batch, + custom_batch=custom_batch, + channel_first=channel_first, + explain_func=explain_func, + explain_func_kwargs=explain_func_kwargs, + softmax=softmax, + device=device, + model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, + **kwargs, + ) + + def custom_preprocess( + self, + x_batch: np.ndarray, + **kwargs, + ) -> None: + """ + Implementation of custom_preprocess_batch. + + Parameters + ---------- + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + kwargs: + Unused. + + Returns + ------- + None + """ + # Asserts. 
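+        # Illustration (assumed setup, mirroring the MNIST examples used in the docstrings
+        # and tests): inputs of shape (1, 28, 28) flatten to 784 features, so
+        # features_in_step=28 gives math.ceil(784 / 28) = 28 perturbation steps in
+        # evaluate_batch below; the assert here guards that the chosen step size is
+        # compatible with the input shape.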
+        asserts.assert_features_in_step(
+            features_in_step=self.features_in_step,
+            input_shape=x_batch.shape[2:],
+        )
+
+    def evaluate_batch(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        a_batch: np.ndarray,
+        **kwargs,
+    ) -> List[float]:
+        """
+        This method performs XAI evaluation on a single batch of explanations.
+        For more information on the specific logic, we refer to the metric’s initialisation docstring.
+
+        Parameters
+        ----------
+        model: ModelInterface
+            A model that is subject to explanation.
+        x_batch: np.ndarray
+            The input to be evaluated on a batch-basis.
+        y_batch: np.ndarray
+            The output to be evaluated on a batch-basis.
+        a_batch: np.ndarray
+            The explanation to be evaluated on a batch-basis.
+        kwargs:
+            Unused.
+
+        Returns
+        -------
+        scores_batch:
+            The evaluation results.
+        """
+
+        # Predict on input x.
+        x_input = model.shape_input(
+            x_batch, x_batch.shape, channel_first=True, batched=True
+        )
+        batch_size = x_batch.shape[0]
+        y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]
+
+        pred_mask = np.abs(y_pred) >= self.eps
+        inv_pred = np.ones(batch_size)
+        inv_pred[pred_mask] = 1.0 / np.abs(y_pred[pred_mask])
+        inv_pred = inv_pred**2
+
+        # Reshape attributions.
+        a_batch = a_batch.reshape(batch_size, -1)
+        n_features = a_batch.shape[-1]
+
+        # Get indices of sorted attributions (ascending).
+        a_indices = np.argsort(a_batch, axis=1)
+
+        n_perturbations = math.ceil(n_features / self.features_in_step)
+        atts = []
+        vars = []
+        x_batch_shape = x_batch.shape
+        for perturbation_step_index in range(n_perturbations):
+            # Perturb input by indices of attributions.
+            a_ix = a_indices[
+                :,
+                perturbation_step_index
+                * self.features_in_step : (perturbation_step_index + 1)
+                * self.features_in_step,
+            ]
+
+            y_pred_perturbs = []
+            for _ in range(self.nr_samples):
+                x_perturbed = self.perturb_func(
+                    arr=x_batch.reshape(batch_size, -1),
+                    indices=a_ix,
+                )
+                x_perturbed = x_perturbed.reshape(*x_batch_shape)
+
+                # Check if the perturbation caused change
+                for x_element, x_perturbed_element in zip(x_batch, x_perturbed):
+                    warn.warn_perturbation_caused_no_change(
+                        x=x_element, x_perturbed=x_perturbed_element
+                    )
+
+                # Predict on perturbed input x.
+                x_input = model.shape_input(
+                    x_perturbed, x_batch.shape, channel_first=True, batched=True
+                )
+                y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
+                y_pred_perturbs.append(y_pred_perturb)
+            y_pred_perturbs = np.stack(y_pred_perturbs, axis=1)
+
+            vars.append(
+                np.mean((y_pred_perturbs - y_pred[:, None]) ** 2, axis=1) * inv_pred
+            )
+            atts.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=1))
+        vars = np.stack(vars, axis=1)
+        atts = np.stack(atts, axis=1)
+
+        return self.similarity_func(a=atts, b=vars, batched=True).tolist()
+
+
 @final
 class MonotonicityCorrelation(Metric[List[float]]):
     """
@@ -307,7 +676,7 @@ def evaluate_instance(
         y_pred = float(model.predict(x_input)[:, y])
 
         inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred)
-        inv_pred = inv_pred ** 2
+        inv_pred = inv_pred**2
 
         # Reshape attributions.
a = a.flatten() diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index a081b6ba7..cdc531c4c 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -9,8 +9,12 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + baseline_replacement_by_indices, + batch_baseline_replacement_by_indices, +) from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( DataType, @@ -28,6 +32,342 @@ from typing_extensions import final +@final +class BatchPixelFlipping(Metric[Union[float, List[float]]]): + """ + Implementation of Pixel-Flipping experiment by Bach et al., 2015. + + The basic idea is to compute a decomposition of a digit for a digit class + and then flip pixels with highly positive, highly negative scores or pixels + with scores close to zero and then to evaluate the impact of these flips + onto the prediction scores (mean prediction is calculated). + + References: + 1) Sebastian Bach et al.: "On pixel-wise explanations for non-linear classifier + decisions by layer-wise relevance propagation." PloS one 10.7 (2015): e0130140. + + Attributes: + - _name: The name of the metric. + - _data_applicability: The data types that the metric implementation currently supports. + - _models: The model types that this metric can work with. + - score_direction: How to interpret the scores, whether higher/ lower values are considered better. + - evaluation_category: What property/ explanation quality that this metric measures. + """ + + name = "Pixel-Flipping" + data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} + model_applicability = {ModelType.TORCH, ModelType.TF} + score_direction = ScoreDirection.LOWER + evaluation_category = EvaluationCategory.FAITHFULNESS + + def __init__( + self, + features_in_step: int = 1, + abs: bool = False, + normalise: bool = True, + normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, + normalise_func_kwargs: Optional[Dict[str, Any]] = None, + perturb_func: Optional[Callable] = None, + perturb_baseline: str = "black", + perturb_func_kwargs: Optional[Dict[str, Any]] = None, + return_aggregate: bool = False, + aggregate_func: Optional[Callable] = None, + return_auc_per_sample: bool = False, + default_plot_func: Optional[Callable] = None, + disable_warnings: bool = False, + display_progressbar: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + features_in_step: integer + The size of the step, default=1. + abs: boolean + Indicates whether absolute operation is applied on the attribution, default=False. + normalise: boolean + Indicates whether normalise operation is applied on the attribution, default=True. + normalise_func: callable + Attribution normalisation function applied in case normalise=True. + If normalise_func=None, the default value is used, default=normalise_by_max. + normalise_func_kwargs: dict + Keyword arguments to be passed to normalise_func on call, default={}. + perturb_func: callable + Input perturbation function. If None, the default value is used, + default=baseline_replacement_by_indices. + perturb_baseline: string + Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", + default="black". 
+ perturb_func_kwargs: dict + Keyword arguments to be passed to perturb_func, default={}. + return_aggregate: boolean + Indicates if an aggregated score should be computed over all instances. + aggregate_func: callable + Callable that aggregates the scores given an evaluation call. + return_auc_per_sample: boolean + Indicates if an AUC score should be computed over the curve and returned. + default_plot_func: callable + Callable that plots the metrics result. + disable_warnings: boolean + Indicates whether the warnings are printed, default=False. + display_progressbar: boolean + Indicates whether a tqdm-progress-bar is printed, default=False. + kwargs: optional + Keyword arguments. + """ + if default_plot_func is None: + default_plot_func = plotting.plot_pixel_flipping_experiment + + super().__init__( + abs=abs, + normalise=normalise, + normalise_func=normalise_func, + normalise_func_kwargs=normalise_func_kwargs, + return_aggregate=return_aggregate, + aggregate_func=aggregate_func, + default_plot_func=default_plot_func, + display_progressbar=display_progressbar, + disable_warnings=disable_warnings, + **kwargs, + ) + + if perturb_func is None: + perturb_func = batch_baseline_replacement_by_indices + + # Save metric-specific attributes. + self.features_in_step = features_in_step + self.return_auc_per_sample = return_auc_per_sample + self.perturb_func = make_perturb_func( + perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline + ) + + # Asserts and warnings. + if not self.disable_warnings: + warn.warn_parameterisation( + metric_name=self.__class__.__name__, + sensitive_params="baseline value 'perturb_baseline'", + citation=( + "Bach, Sebastian, et al. 'On pixel-wise explanations for non-linear classifier" + " decisions by layer - wise relevance propagation.' PloS one 10.7 (2015) " + "e0130140" + ), + ) + + def __call__( + self, + model, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: Optional[np.ndarray] = None, + s_batch: Optional[np.ndarray] = None, + channel_first: Optional[bool] = None, + explain_func: Optional[Callable] = None, + explain_func_kwargs: Optional[Dict] = None, + model_predict_kwargs: Optional[Dict] = None, + softmax: Optional[bool] = True, + device: Optional[str] = None, + batch_size: int = 64, + **kwargs, + ) -> List[float]: + """ + This implementation represents the main logic of the metric and makes the class object callable. + It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), + output labels (y_batch) and a torch or tensorflow model (model). + + Calls general_preprocess() with all relevant arguments, calls + () on each instance, and saves results to evaluation_scores. + Calls custom_postprocess() afterwards. Finally returns evaluation_scores. + + Parameters + ---------- + model: torch.nn.Module, tf.keras.Model + A torch or tensorflow model that is subject to explanation. + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + y_batch: np.ndarray + A np.ndarray which contains the output labels that are explained. + a_batch: np.ndarray, optional + A np.ndarray which contains pre-computed attributions i.e., explanations. + s_batch: np.ndarray, optional + A np.ndarray which contains segmentation masks that matches the input. + channel_first: boolean, optional + Indicates of the image dimensions are channel first, or channel last. + Inferred from the input shape if None. + explain_func: callable + Callable generating attributions. 
+ explain_func_kwargs: dict, optional + Keyword arguments to be passed to explain_func on call. + model_predict_kwargs: dict, optional + Keyword arguments to be passed to the model's predict method. + softmax: boolean + Indicates whether to use softmax probabilities or logits in model prediction. + This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. + device: string + Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". + kwargs: optional + Keyword arguments. + + Returns + ------- + evaluation_scores: list + a list of Any with the evaluation scores of the concerned batch. + + Examples: + -------- + # Minimal imports. + >> import quantus + >> from quantus import LeNet + >> import torch + + # Enable GPU. + >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). + >> model = LeNet() + >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) + + # Load MNIST datasets and make loaders. + >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) + >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) + + # Load a batch of inputs and outputs to use for XAI evaluation. + >> x_batch, y_batch = iter(test_loader).next() + >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() + + # Generate Saliency attributions of the test set batch of the test set. + >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) + >> a_batch_saliency = a_batch_saliency.cpu().numpy() + + # Initialise the metric and evaluate explanations by calling the metric instance. + >> metric = Metric(abs=True, normalise=False) + >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) + """ + return super().__call__( + model=model, + x_batch=x_batch, + y_batch=y_batch, + a_batch=a_batch, + s_batch=s_batch, + custom_batch=None, + channel_first=channel_first, + explain_func=explain_func, + explain_func_kwargs=explain_func_kwargs, + softmax=softmax, + device=device, + model_predict_kwargs=model_predict_kwargs, + batch_size=batch_size, + **kwargs, + ) + + def custom_preprocess( + self, + x_batch: np.ndarray, + **kwargs, + ) -> None: + """ + Implementation of custom_preprocess_batch. + + Parameters + ---------- + x_batch: np.ndarray + A np.ndarray which contains the input data that are explained. + kwargs: + Unused. + + Returns + ------- + None + """ + # Asserts. + asserts.assert_features_in_step( + features_in_step=self.features_in_step, + input_shape=x_batch.shape[2:], + ) + + @property + def get_auc_score(self): + """Calculate the area under the curve (AUC) score for several test samples.""" + return np.mean( + [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] + ) + + def evaluate_batch( + self, + model: ModelInterface, + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[Union[float, List[float]]]: + """ + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. + + Parameters + ---------- + model: ModelInterface + A ModelInteface that is subject to explanation. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. 
+ a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + kwargs: + Unused. + + Returns + ------- + scores_batch: + The evaluation results. + """ + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a_batch, axis=1) + + # Prepare lists. + n_perturbations = math.ceil(n_features / self.features_in_step) + preds = [] + x_perturbed = x_batch.copy() + x_batch_shape = x_batch.shape + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. + a_ix = a_indices[ + :, + perturbation_step_index + * self.features_in_step : (perturbation_step_index + 1) + * self.features_in_step, + ] + x_perturbed = self.perturb_func( + arr=x_perturbed.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change( + x=x_element, x_perturbed=x_perturbed_element + ) + + # Predict on perturbed input x. + x_input = model.shape_input( + x_perturbed, x_batch.shape, channel_first=True, batched=True + ) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + preds.append(y_pred_perturb) + preds = np.stack(preds, axis=1) + + if self.return_auc_per_sample: + return utils.calculate_auc(preds, batched=True).tolist() + + return preds.tolist() + + @final class PixelFlipping(Metric[Union[float, List[float]]]): """ From 2952fdf5c351ca859d2ca574ae8676bed86fc4f2 Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Sat, 27 Jul 2024 10:35:46 +0200 Subject: [PATCH 03/11] added tests for batching --- quantus/functions/perturb_func.py | 2 +- quantus/functions/similarity_func.py | 30 +++++++++----- tests/functions/test_perturb_func.py | 35 ++++++++++++++++ tests/metrics/test_faithfulness_metrics.py | 47 +++++++++++----------- 4 files changed, 81 insertions(+), 33 deletions(-) diff --git a/quantus/functions/perturb_func.py b/quantus/functions/perturb_func.py index 7db8daa36..f653feb25 100644 --- a/quantus/functions/perturb_func.py +++ b/quantus/functions/perturb_func.py @@ -120,7 +120,7 @@ def baseline_replacement_by_indices( def batch_baseline_replacement_by_indices( arr: np.array, - indices: Tuple[slice, ...], + indices: np.array, perturb_baseline: Union[float, int, str, np.array], **kwargs, ) -> np.array: diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index da3c2f0d0..6c2b83eff 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -14,7 +14,9 @@ import skimage -def correlation_spearman(a: np.array, b: np.array, batched=False, **kwargs) -> float: +def correlation_spearman( + a: np.array, b: np.array, batched: bool = False, **kwargs +) -> Union[float, np.array]: """ Calculate Spearman rank of two images (or explanations). @@ -31,8 +33,8 @@ def correlation_spearman(a: np.array, b: np.array, batched=False, **kwargs) -> f Returns ------- - float - The similarity score. + Union[float, np.array] + The similarity score or a batch of similarity scores. 
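+
+    Examples
+    --------
+    Illustrative call only (row-wise pairing of a and b is assumed when batched=True):
+
+    >> a = np.random.rand(8, 100)
+    >> b = np.random.rand(8, 100)
+    >> scores = correlation_spearman(a, b, batched=True)  # expected: one score per row, shape (8,)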
""" if batched: assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" @@ -43,7 +45,9 @@ def correlation_spearman(a: np.array, b: np.array, batched=False, **kwargs) -> f return scipy.stats.spearmanr(a, b)[0] -def correlation_pearson(a: np.array, b: np.array, batched=False, **kwargs) -> float: +def correlation_pearson( + a: np.array, b: np.array, batched: bool = False, **kwargs +) -> Union[float, np.array]: """ Calculate Pearson correlation of two images (or explanations). @@ -60,8 +64,8 @@ def correlation_pearson(a: np.array, b: np.array, batched=False, **kwargs) -> fl Returns ------- - float - The similarity score. + Union[float, np.array] + The similarity score or a batch of similarity scores. """ if batched: assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" @@ -69,7 +73,9 @@ def correlation_pearson(a: np.array, b: np.array, batched=False, **kwargs) -> fl return scipy.stats.pearsonr(a, b)[0] -def correlation_kendall_tau(a: np.array, b: np.array, **kwargs) -> float: +def correlation_kendall_tau( + a: np.array, b: np.array, batched: bool = False, **kwargs +) -> Union[float, np.array]: """ Calculate Kendall Tau correlation of two images (or explanations). @@ -79,14 +85,20 @@ def correlation_kendall_tau(a: np.array, b: np.array, **kwargs) -> float: The first array to use for similarity scoring. b: np.ndarray The second array to use for similarity scoring. + batched: bool + True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size and F is the number of features kwargs: optional Keyword arguments. Returns ------- - float - The similarity score. + Union[float, np.array] + The similarity score or a batch of similarity scores. """ + if batched: + assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" + # No support for axis currently, so just iterating over the batch + return np.array([scipy.stats.kendalltau(a_i, b_i)[0] for a_i, b_i in zip(a, b)]) return scipy.stats.kendalltau(a, b)[0] diff --git a/tests/functions/test_perturb_func.py b/tests/functions/test_perturb_func.py index 0f2b87b67..4d76dea78 100644 --- a/tests/functions/test_perturb_func.py +++ b/tests/functions/test_perturb_func.py @@ -30,6 +30,11 @@ def input_zeros_2d_3ch_flattened(): return np.zeros(shape=(3, 224, 224)).flatten() +@pytest.fixture +def input_batch_zeros_2d_3ch_flattened(): + return np.zeros(shape=(2, 3, 224, 224)).reshape(2, -1) + + @pytest.fixture def input_uniform_2d_3ch_flattened(): return np.random.uniform(0, 0.1, size=(3, 224, 224)).flatten() @@ -161,6 +166,36 @@ def test_baseline_replacement_by_indices( ), f"Test failed.{out}" +@pytest.mark.perturb_func +@pytest.mark.parametrize( + "data,params,expected", + [ + ( + lazy_fixture("input_batch_zeros_2d_3ch_flattened"), + { + "indices": np.array([[0, 2], [2, 5]]), + "perturb_baseline": 1.0, + }, + 1, + ), + ], +) +def test_batch_baseline_replacement_by_indices( + data: np.ndarray, params: dict, expected: Union[float, dict, bool] +): + batch_size = data.shape[0] + # Output + out = batch_baseline_replacement_by_indices(arr=data, **params) + + # Indices + indices = params["indices"] + + if isinstance(expected, (int, float)): + assert np.all( + out[np.arange(batch_size)[:, None], indices] == expected + ), f"Test failed.{out}" + + @pytest.mark.perturb_func @pytest.mark.parametrize( "data,params,expected", diff --git a/tests/metrics/test_faithfulness_metrics.py b/tests/metrics/test_faithfulness_metrics.py index 775fdd58c..d8b8b690c 100644 --- 
a/tests/metrics/test_faithfulness_metrics.py +++ b/tests/metrics/test_faithfulness_metrics.py @@ -6,6 +6,7 @@ from quantus.functions.explanation_func import explain from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, baseline_replacement_by_indices, noisy_linear_imputation, ) @@ -39,7 +40,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "perturb_baseline": "mean", "similarity_func": correlation_spearman, @@ -61,7 +62,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "perturb_baseline": "mean", "similarity_func": correlation_spearman, @@ -84,7 +85,7 @@ { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "similarity_func": correlation_spearman, "normalise": True, @@ -105,7 +106,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "similarity_func": correlation_spearman, "normalise": True, @@ -126,7 +127,7 @@ lazy_fixture("load_mnist_images_tf"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "perturb_baseline": "mean", "similarity_func": correlation_spearman, @@ -148,7 +149,7 @@ lazy_fixture("load_mnist_images_tf"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "similarity_func": correlation_spearman, "normalise": True, @@ -169,7 +170,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -190,7 +191,7 @@ { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -209,7 +210,7 @@ { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -227,7 +228,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -305,7 +306,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "normalise": True, @@ -326,7 +327,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 196, "perturb_baseline": "uniform", "normalise": True, @@ -347,7 +348,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": 
baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "normalise": True, @@ -369,7 +370,7 @@ def test_faithfulness_correlation( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "abs": True, @@ -391,7 +392,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "normalise": True, @@ -413,7 +414,7 @@ def test_faithfulness_correlation( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "uniform", "features_in_step": 10, "normalise": True, @@ -611,7 +612,7 @@ def test_iterative_removal_of_features( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "black", "normalise": True, @@ -632,7 +633,7 @@ def test_iterative_removal_of_features( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "white", "normalise": True, @@ -654,7 +655,7 @@ def test_iterative_removal_of_features( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "mean", "normalise": True, @@ -675,7 +676,7 @@ def test_iterative_removal_of_features( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "black", "normalise": True, @@ -697,7 +698,7 @@ def test_iterative_removal_of_features( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "black", "features_in_step": 10, "normalise": True, @@ -980,7 +981,7 @@ def test_monotonicity_correlation( "init": { "features_in_step": 10, "normalise": False, - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "disable_warnings": True, }, @@ -1017,7 +1018,7 @@ def test_monotonicity_correlation( "init": { "features_in_step": 10, "normalise": False, - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "disable_warnings": True, }, From 4f445108fb3206a626b6ba3060d88469570c545e Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Sat, 27 Jul 2024 13:00:00 +0200 Subject: [PATCH 04/11] removing Batched classes --- quantus/metrics/faithfulness/__init__.py | 14 +- .../faithfulness/faithfulness_correlation.py | 375 +---------------- .../faithfulness/faithfulness_estimate.py | 364 +--------------- quantus/metrics/faithfulness/monotonicity.py | 361 +--------------- .../faithfulness/monotonicity_correlation.py | 396 +----------------- .../metrics/faithfulness/pixel_flipping.py | 366 +--------------- 6 files changed, 33 
insertions(+), 1843 deletions(-) diff --git a/quantus/metrics/faithfulness/__init__.py b/quantus/metrics/faithfulness/__init__.py index 505ac3bd8..dcf83d6e2 100644 --- a/quantus/metrics/faithfulness/__init__.py +++ b/quantus/metrics/faithfulness/__init__.py @@ -5,24 +5,16 @@ # Quantus project URL: . from quantus.metrics.faithfulness.faithfulness_correlation import ( - BatchFaithfulnessCorrelation, FaithfulnessCorrelation, ) -from quantus.metrics.faithfulness.faithfulness_estimate import ( - BatchFaithfulnessEstimate, - FaithfulnessEstimate, -) +from quantus.metrics.faithfulness.faithfulness_estimate import FaithfulnessEstimate from quantus.metrics.faithfulness.infidelity import Infidelity from quantus.metrics.faithfulness.irof import IROF -from quantus.metrics.faithfulness.monotonicity import BatchMonotonicity, Monotonicity +from quantus.metrics.faithfulness.monotonicity import Monotonicity from quantus.metrics.faithfulness.monotonicity_correlation import ( - BatchMonotonicityCorrelation, MonotonicityCorrelation, ) -from quantus.metrics.faithfulness.pixel_flipping import ( - BatchPixelFlipping, - PixelFlipping, -) +from quantus.metrics.faithfulness.pixel_flipping import PixelFlipping from quantus.metrics.faithfulness.region_perturbation import RegionPerturbation from quantus.metrics.faithfulness.road import ROAD from quantus.metrics.faithfulness.selectivity import Selectivity diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 3b5238f76..805c342ce 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -10,10 +10,7 @@ import numpy as np -from quantus.functions.perturb_func import ( - baseline_replacement_by_indices, - batch_baseline_replacement_by_indices, -) +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -33,7 +30,7 @@ @final -class BatchFaithfulnessCorrelation(Metric[List[float]]): +class FaithfulnessCorrelation(Metric[List[float]]): """ Implementation of faithfulness correlation by Bhatt et al., 2020. @@ -331,6 +328,10 @@ def evaluate_batch( scores_batch: The evaluation results. """ + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + # Flatten the attributions. batch_size = a_batch.shape[0] a_batch = a_batch.reshape(batch_size, -1) @@ -383,367 +384,3 @@ def evaluate_batch( similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) return similarity.tolist() - - -@final -class FaithfulnessCorrelation(Metric[List[float]]): - """ - Implementation of faithfulness correlation by Bhatt et al., 2020. - - The Faithfulness Correlation metric intend to capture an explanation's relative faithfulness - (or 'fidelity') with respect to the model behaviour. - - Faithfulness correlation scores shows to what extent the predicted logits of each modified test point and - the average explanation attribution for only the subset of features are (linearly) correlated, taking the - average over multiple runs and test samples. The metric returns one float per input-attribution pair that - ranges between -1 and 1, where higher scores are better. 
- - For each test sample, |S| features are randomly selected and replace them with baseline values (zero baseline - or average of set). Thereafter, Pearson’s correlation coefficient between the predicted logits of each modified - test point and the average explanation attribution for only the subset of features is calculated. Results is - average over multiple runs and several test samples. - - References: - 1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model - explanations." IJCAI (2020): 3016-3022. - - Attributes: - - _name: The name of the metric. - - _data_applicability: The data types that the metric implementation currently supports. - - _models: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. - """ - - name = "Faithfulness Correlation" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.HIGHER - evaluation_category = EvaluationCategory.FAITHFULNESS - - def __init__( - self, - similarity_func: Optional[Callable] = None, - nr_runs: int = 100, - subset_size: int = 224, - abs: bool = False, - normalise: bool = True, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, - normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Optional[Callable] = None, - perturb_baseline: str = "black", - perturb_func_kwargs: Optional[Dict[str, Any]] = None, - return_aggregate: bool = True, - aggregate_func: Optional[Callable] = None, - default_plot_func: Optional[Callable] = None, - disable_warnings: bool = False, - display_progressbar: bool = False, - **kwargs, - ): - """ - Parameters - ---------- - similarity_func: callable - Similarity function applied to compare input and perturbed input. - If None, the default value is used, default=correlation_pearson. - nr_runs: integer - The number of runs (for each input and explanation pair), default=100. - subset_size: integer - The size of subset, default=224. - abs: boolean - Indicates whether absolute operation is applied on the attribution, default=False. - normalise: boolean - Indicates whether normalise operation is applied on the attribution, default=True. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - If normalise_func=None, the default value is used, default=normalise_by_max. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call, default={}. - perturb_func: callable - Input perturbation function. If None, the default value is used, - default=baseline_replacement_by_indices. - perturb_baseline: string - Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", - default="black". - perturb_func_kwargs: dict - Keyword arguments to be passed to perturb_func, default={}. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed, default=False. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed, default=False. - kwargs: optional - Keyword arguments. 
- """ - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - # Save metric-specific attributes. - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - - if similarity_func is None: - similarity_func = correlation_pearson - - self.similarity_func = similarity_func - self.nr_runs = nr_runs - self.subset_size = subset_size - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) - - # Asserts and warnings. - if not self.disable_warnings: - warn.warn_parameterisation( - metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline', size of subset |S| 'subset_size'" - " and the number of runs (for each input and explanation pair) " - "'nr_runs'" - ), - citation=( - "Bhatt, Umang, Adrian Weller, and José MF Moura. 'Evaluating and aggregating " - "feature-based model explanations.' arXiv preprint arXiv:2005.00631 (2020)" - ), - ) - - def __call__( - self, - model, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: Optional[np.ndarray] = None, - s_batch: Optional[np.ndarray] = None, - channel_first: Optional[bool] = None, - explain_func: Optional[Callable] = None, - explain_func_kwargs: Optional[Dict] = None, - model_predict_kwargs: Optional[Dict] = None, - softmax: Optional[bool] = False, - device: Optional[str] = None, - batch_size: int = 64, - custom_batch: Optional[Any] = None, - **kwargs, - ) -> List[float]: - """ - This implementation represents the main logic of the metric and makes the class object callable. - It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), - output labels (y_batch) and a torch or tensorflow model (model). - - Calls general_preprocess() with all relevant arguments, calls - () on each instance, and saves results to evaluation_scores. - Calls custom_postprocess() afterwards. Finally returns evaluation_scores. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - channel_first: boolean, optional - Indicates of the image dimensions are channel first, or channel last. - Inferred from the input shape if None. - explain_func: callable - Callable generating attributions. - explain_func_kwargs: dict, optional - Keyword arguments to be passed to explain_func on call. - model_predict_kwargs: dict, optional - Keyword arguments to be passed to the model's predict method. - softmax: boolean - Indicates whether to use softmax probabilities or logits in model prediction. - This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. - device: string - Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". - custom_batch: any - Any object that can be passed to the evaluation process. 
- Gives flexibility to the user to adapt for implementing their own metric. - kwargs: optional - Keyword arguments. - - Returns - ------- - evaluation_scores: list - a list of Any with the evaluation scores of the concerned batch. - - Examples: - -------- - # Minimal imports. - >> import quantus - >> from quantus import LeNet - >> import torch - - # Enable GPU. - >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). - >> model = LeNet() - >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) - - # Load MNIST datasets and make loaders. - >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) - >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) - - # Load a batch of inputs and outputs to use for XAI evaluation. - >> x_batch, y_batch = iter(test_loader).next() - >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() - - # Generate Saliency attributions of the test set batch of the test set. - >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) - >> a_batch_saliency = a_batch_saliency.cpu().numpy() - - # Initialise the metric and evaluate explanations by calling the metric instance. - >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) - """ - return super().__call__( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - custom_batch=custom_batch, - channel_first=channel_first, - explain_func=explain_func, - explain_func_kwargs=explain_func_kwargs, - softmax=softmax, - device=device, - model_predict_kwargs=model_predict_kwargs, - batch_size=batch_size, - **kwargs, - ) - - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Flatten the attributions. - a = a.flatten() - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - pred_deltas = [] - att_sums = [] - - # For each test data point, execute a couple of runs. - for i_ix in range(self.nr_runs): - # Randomly mask by subset size. - a_ix = np.random.choice(a.shape[0], self.subset_size, replace=False) - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(float(y_pred - y_pred_perturb)) - - # Sum attributions of the random subset. - att_sums.append(np.sum(a[a_ix])) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - - return similarity - - def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: - """ - Implementation of custom_preprocess_batch. 
- - Parameters - ---------- - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - kwargs: - Unused. - - Returns - ------- - tuple - In addition to the x_batch, y_batch, a_batch, s_batch and custom_batch, - returning a custom preprocess batch (custom_preprocess_batch). - """ - # Asserts. - asserts.assert_value_smaller_than_input_size( - x=x_batch, value=self.subset_size, value_name="subset_size" - ) - - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ) -> List[float]: - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A ModelInterface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - kwargs: - Unused. - - Returns - ------- - scores_batch: - The evaluation results. - """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 3d9adf88d..36d73622d 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -11,10 +11,7 @@ import numpy as np import math -from quantus.functions.perturb_func import ( - baseline_replacement_by_indices, - batch_baseline_replacement_by_indices, -) +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -34,7 +31,7 @@ @final -class BatchFaithfulnessEstimate(Metric[List[float]]): +class FaithfulnessEstimate(Metric[List[float]]): """ Implementation of Faithfulness Estimate by Alvares-Melis at el., 2018a and 2018b. @@ -316,6 +313,10 @@ def evaluate_batch( scores_batch: The evaluation results. """ + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + # Flatten the attributions. batch_size = a_batch.shape[0] a_batch = a_batch.reshape(batch_size, -1) @@ -369,356 +370,3 @@ def evaluate_batch( similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) return similarity.tolist() - - -@final -class FaithfulnessEstimate(Metric[List[float]]): - """ - Implementation of Faithfulness Estimate by Alvares-Melis at el., 2018a and 2018b. - - Computes the correlations of probability drops and the relevance scores on various points, - showing the aggregate statistics. - - References: - 1) David Alvarez-Melis and Tommi S. Jaakkola.: "Towards robust interpretability with self-explaining - neural networks." NeurIPS (2018): 7786-7795. - - Attributes: - - _name: The name of the metric. - - _data_applicability: The data types that the metric implementation currently supports. - - _models: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. 
- """ - - name = "Faithfulness Estimate" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.HIGHER - evaluation_category = EvaluationCategory.FAITHFULNESS - - def __init__( - self, - similarity_func: Optional[Callable] = None, - features_in_step: int = 1, - abs: bool = False, - normalise: bool = True, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, - normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Optional[Callable] = None, - perturb_baseline: str = "black", - perturb_func_kwargs: Optional[Dict[str, Any]] = None, - return_aggregate: bool = False, - aggregate_func: Optional[Callable] = None, - default_plot_func: Optional[Callable] = None, - disable_warnings: bool = False, - display_progressbar: bool = False, - **kwargs, - ): - """ - Parameters - ---------- - similarity_func: callable - Similarity function applied to compare input and perturbed input. - If None, the default value is used, default=correlation_spearman. - features_in_step: integer - The size of the step, default=1. - abs: boolean - Indicates whether absolute operation is applied on the attribution, default=False. - normalise: boolean - Indicates whether normalise operation is applied on the attribution, default=True. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - If normalise_func=None, the default value is used, default=normalise_by_max. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call, default={}. - perturb_func: callable - Input perturbation function. If None, the default value is used, - default=baseline_replacement_by_indices. - perturb_baseline: string - Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", - default="black". - perturb_func_kwargs: dict - Keyword arguments to be passed to perturb_func, default={}. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed, default=False. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed, default=False. - kwargs: optional - Keyword arguments. - """ - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - # Save metric-specific attributes. - if similarity_func is None: - similarity_func = correlation_pearson - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - self.similarity_func = similarity_func - self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) - - # Asserts and warnings. - if not self.disable_warnings: - warn.warn_parameterisation( - metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline' and similarity function " - "'similarity_func'" - ), - citation=( - "Alvarez-Melis, David, and Tommi S. Jaakkola. 
'Towards robust interpretability" - " with self-explaining neural networks.' arXiv preprint arXiv:1806.07538 (2018)" - ), - ) - - def __call__( - self, - model, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: Optional[np.ndarray] = None, - s_batch: Optional[np.ndarray] = None, - channel_first: Optional[bool] = None, - explain_func: Optional[Callable] = None, - explain_func_kwargs: Optional[Dict] = None, - model_predict_kwargs: Optional[Dict] = None, - softmax: Optional[bool] = False, - device: Optional[str] = None, - batch_size: int = 64, - custom_batch: Optional[Any] = None, - **kwargs, - ) -> List[float]: - """ - This implementation represents the main logic of the metric and makes the class object callable. - It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), - output labels (y_batch) and a torch or tensorflow model (model). - - Calls general_preprocess() with all relevant arguments, calls - () on each instance, and saves results to evaluation_scores. - Calls custom_postprocess() afterwards. Finally returns evaluation_scores. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - channel_first: boolean, optional - Indicates of the image dimensions are channel first, or channel last. - Inferred from the input shape if None. - explain_func: callable - Callable generating attributions. - explain_func_kwargs: dict, optional - Keyword arguments to be passed to explain_func on call. - model_predict_kwargs: dict, optional - Keyword arguments to be passed to the model's predict method. - softmax: boolean - Indicates whether to use softmax probabilities or logits in model prediction. - This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. - device: string - Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". - custom_batch: any - Any object that can be passed to the evaluation process. - Gives flexibility to the user to adapt for implementing their own metric. - kwargs: optional - Keyword arguments. - - Returns - ------- - evaluation_scores: list - a list of Any with the evaluation scores of the concerned batch. - - Examples: - -------- - # Minimal imports. - >> import quantus - >> from quantus import LeNet - >> import torch - - # Enable GPU. - >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). - >> model = LeNet() - >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) - - # Load MNIST datasets and make loaders. - >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) - >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) - - # Load a batch of inputs and outputs to use for XAI evaluation. - >> x_batch, y_batch = iter(test_loader).next() - >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() - - # Generate Saliency attributions of the test set batch of the test set. 
- >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) - >> a_batch_saliency = a_batch_saliency.cpu().numpy() - - # Initialise the metric and evaluate explanations by calling the metric instance. - >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) - """ - return super().__call__( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - custom_batch=custom_batch, - channel_first=channel_first, - explain_func=explain_func, - explain_func_kwargs=explain_func_kwargs, - model_predict_kwargs=model_predict_kwargs, - softmax=softmax, - device=device, - batch_size=batch_size, - **kwargs, - ) - - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - # Flatten the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - pred_deltas = [None for _ in range(n_perturbations)] - att_sums = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas[i_ix] = float(y_pred - y_pred_perturb) - - # Sum attributions. - att_sums[i_ix] = np.sum(a[a_ix]) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - return similarity - - def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - kwargs: - Unused. - - Returns - ------- - tuple - In addition to the x_batch, y_batch, a_batch, s_batch and custom_batch, - returning a custom preprocess batch (custom_preprocess_batch). - """ - # Asserts. - asserts.assert_features_in_step( - features_in_step=self.features_in_step, - input_shape=x_batch.shape[2:], - ) - - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ) -> List[float]: - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A ModelInterface that is subject to explanation. 
- x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - kwargs: - Unused. - - Returns - ------- - scores_batch: - The evaluation results. - """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index d039d8799..9263a41e3 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -11,10 +11,7 @@ import numpy as np import math -from quantus.functions.perturb_func import ( - baseline_replacement_by_indices, - batch_baseline_replacement_by_indices, -) +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( DataType, @@ -33,7 +30,7 @@ @final -class BatchMonotonicity(Metric[List[float]]): +class Monotonicity(Metric[List[float]]): """ Implementation of Monotonicity metric by Arya at el., 2019. @@ -314,7 +311,10 @@ def evaluate_batch( scores_batch: The evaluation results. """ - # Prepare shapes. + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + batch_size = a_batch.shape[0] a_batch = a_batch.reshape(batch_size, -1) n_features = a_batch.shape[-1] @@ -362,352 +362,3 @@ def evaluate_batch( preds = np.stack(preds, axis=1) return np.all(np.diff(preds) >= 0, axis=1).tolist() - - -@final -class Monotonicity(Metric[List[float]]): - """ - Implementation of Monotonicity metric by Arya at el., 2019. - - Monotonicity tests if adding more positive evidence increases the probability - of classification in the specified class. - - It captures attributions' faithfulness by incrementally adding each attribute - in order of increasing importance and evaluating the effect on model performance. - As more features are added, the performance of the model is expected to increase - and thus result in monotonically increasing model performance. - - References: - 1) Vijay Arya et al.: "One explanation does not fit all: A toolkit and taxonomy of ai explainability - techniques." arXiv preprint arXiv:1909.03012 (2019). - 2) Ronny Luss et al.: "Generating contrastive explanations with monotonic attribute functions." - arXiv preprint arXiv:1905.12698 (2019). - - Attributes: - - _name: The name of the metric. - - _data_applicability: The data types that the metric implementation currently supports. - - _models: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. 
- """ - - name = "Monotonicity" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.HIGHER - evaluation_category = EvaluationCategory.FAITHFULNESS - - def __init__( - self, - features_in_step: int = 1, - abs: bool = True, - normalise: bool = True, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, - normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Optional[Callable] = None, - perturb_baseline: str = "black", - perturb_func_kwargs: Optional[Dict[str, Any]] = None, - return_aggregate: bool = False, - aggregate_func: Optional[Callable] = None, - default_plot_func: Optional[Callable] = None, - disable_warnings: bool = False, - display_progressbar: bool = False, - **kwargs, - ): - """ - Parameters - ---------- - features_in_step: integer - The size of the step, default=1. - abs: boolean - Indicates whether absolute operation is applied on the attribution, default=True. - normalise: boolean - Indicates whether normalise operation is applied on the attribution, default=True. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - If normalise_func=None, the default value is used, default=normalise_by_max. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call, default={}. - perturb_func: callable - Input perturbation function. If None, the default value is used, - default=baseline_replacement_by_indices. - perturb_baseline: string - Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", - default="black". - perturb_func_kwargs: dict - Keyword arguments to be passed to perturb_func, default={}. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed, default=False. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed, default=False. - kwargs: optional - Keyword arguments. - """ - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - - # Save metric-specific attributes. - self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) - - # Asserts and warnings. - if not self.disable_warnings: - warn.warn_parameterisation( - metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline', also, the monotonicity " - "constraint between your given model and explanation method should be assessed" - ), - citation=( - "Arya, Vijay, et al. 'One explanation does not fit all: A toolkit and taxonomy" - " of ai explainability techniques.' 
arXiv preprint arXiv:1909.03012 (2019)" - ), - ) - - def __call__( - self, - model, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: Optional[np.ndarray] = None, - s_batch: Optional[np.ndarray] = None, - channel_first: Optional[bool] = None, - explain_func: Optional[Callable] = None, - explain_func_kwargs: Optional[Dict] = None, - model_predict_kwargs: Optional[Dict] = None, - softmax: Optional[bool] = True, - device: Optional[str] = None, - batch_size: int = 64, - **kwargs, - ) -> List[float]: - """ - This implementation represents the main logic of the metric and makes the class object callable. - It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), - output labels (y_batch) and a torch or tensorflow model (model). - - Calls general_preprocess() with all relevant arguments, calls - () on each instance, and saves results to evaluation_scores. - Calls custom_postprocess() afterwards. Finally returns evaluation_scores. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - channel_first: boolean, optional - Indicates of the image dimensions are channel first, or channel last. - Inferred from the input shape if None. - explain_func: callable - Callable generating attributions. - explain_func_kwargs: dict, optional - Keyword arguments to be passed to explain_func on call. - model_predict_kwargs: dict, optional - Keyword arguments to be passed to the model's predict method. - softmax: boolean - Indicates whether to use softmax probabilities or logits in model prediction. - This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. - device: string - Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". - kwargs: optional - Keyword arguments. - - Returns - ------- - evaluation_scores: list - a list of Any with the evaluation scores of the concerned batch. - - Examples: - -------- - # Minimal imports. - >> import quantus - >> from quantus import LeNet - >> import torch - - # Enable GPU. - >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). - >> model = LeNet() - >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) - - # Load MNIST datasets and make loaders. - >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) - >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) - - # Load a batch of inputs and outputs to use for XAI evaluation. - >> x_batch, y_batch = iter(test_loader).next() - >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() - - # Generate Saliency attributions of the test set batch of the test set. - >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) - >> a_batch_saliency = a_batch_saliency.cpu().numpy() - - # Initialise the metric and evaluate explanations by calling the metric instance. 
- >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) - """ - return super().__call__( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - custom_batch=None, - channel_first=channel_first, - explain_func=explain_func, - explain_func_kwargs=explain_func_kwargs, - softmax=softmax, - device=device, - model_predict_kwargs=model_predict_kwargs, - batch_size=batch_size, - **kwargs, - ) - - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Prepare shapes. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - - # Copy the input x but fill with baseline values. - baseline_value = utils.get_baseline_value( - value=self.perturb_func.keywords["perturb_baseline"], # type: ignore - arr=x, - return_shape=x.shape, # TODO. Double-check this over using = (1,). - ) - x_baseline = np.full(x.shape, baseline_value) - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_baseline = self.perturb_func( - arr=x_baseline, - indices=a_ix, - indexed_axes=self.a_axes, - ) - - # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). - x_input = model.shape_input(x_baseline, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - return np.all(np.diff(preds) >= 0) - - def custom_preprocess( - self, - x_batch: np.ndarray, - **kwargs, - ) -> None: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - kwargs: - Unused. - - Returns - ------- - None - """ - # Asserts. - asserts.assert_features_in_step( - features_in_step=self.features_in_step, - input_shape=x_batch.shape[2:], - ) - - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ) -> List[float]: - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A ModelInterface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - kwargs: - Unused. - - Returns - ------- - scores_batch: - The evaluation results. 
- """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 9505968f0..d14e8f2ad 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -12,10 +12,7 @@ import numpy as np import math -from quantus.functions.perturb_func import ( - baseline_replacement_by_indices, - batch_baseline_replacement_by_indices, -) +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_spearman from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -35,7 +32,7 @@ @final -class BatchMonotonicityCorrelation(Metric[List[float]]): +class MonotonicityCorrelation(Metric[List[float]]): """ Implementation of Monotonicity Correlation metric by Nguyen at el., 2020. @@ -347,6 +344,10 @@ def evaluate_batch( inv_pred[np.abs(y_pred) >= self.eps] = 1.0 / np.abs(y_pred) inv_pred = inv_pred**2 + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + # Reshape attributions. a_batch = a_batch.reshape(batch_size, -1) n_features = a_batch.shape[-1] @@ -397,388 +398,3 @@ def evaluate_batch( atts = np.stack(atts, axis=1) return self.similarity_func(a=atts, b=vars, batched=True).tolist() - - -@final -class MonotonicityCorrelation(Metric[List[float]]): - """ - Implementation of Monotonicity Correlation metric by Nguyen at el., 2020. - - Monotonicity measures the (Spearman’s) correlation coefficient of the absolute values of the attributions - and the uncertainty in probability estimation. The paper argues that if attributions are not monotonic - then they are not providing the correct importance of the feature. - - References: - 1) An-phi Nguyen and María Rodríguez Martínez.: "On quantitative aspects of model - interpretability." arXiv preprint arXiv:2007.07584 (2020). - - Attributes: - - _name: The name of the metric. - - _data_applicability: The data types that the metric implementation currently supports. - - _models: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. 
- """ - - name = "Monotonicity" - data_applicability = { - DataType.IMAGE, - DataType.TIMESERIES, - DataType.TABULAR, - } - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.HIGHER - evaluation_category = EvaluationCategory.FAITHFULNESS - - def __init__( - self, - similarity_func: Optional[Callable] = None, - eps: float = 1e-5, - nr_samples: int = 100, - features_in_step: int = 1, - abs: bool = True, - normalise: bool = True, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, - normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Optional[Callable] = None, - perturb_baseline: str = "uniform", - perturb_func_kwargs: Optional[Dict[str, Any]] = None, - return_aggregate: bool = False, - aggregate_func: Optional[Callable] = None, - default_plot_func: Optional[Callable] = None, - disable_warnings: bool = False, - display_progressbar: bool = False, - **kwargs, - ): - """ - Parameters - ---------- - similarity_func: callable - Similarity function applied to compare input and perturbed input. - If None, the default value is used, default=correlation_spearman. - eps: float - Attributions threshold, default=1e-5. - nr_samples: integer - The number of samples to iterate over, default=100. - features_in_step: integer - The size of the step, default=1. - abs: boolean - Indicates whether absolute operation is applied on the attribution, default=True. - normalise: boolean - Indicates whether normalise operation is applied on the attribution, default=True. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - If normalise_func=None, the default value is used, default=normalise_by_max. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call, default={}. - perturb_func: callable - Input perturbation function. If None, the default value is used, - default=baseline_replacement_by_indices. - perturb_baseline: string - Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", - default="uniform". - perturb_func_kwargs: dict - Keyword arguments to be passed to perturb_func, default={}. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed, default=False. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed, default=False. - kwargs: optional - Keyword arguments. - """ - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - # Save metric-specific attributes. - if similarity_func is None: - similarity_func = correlation_spearman - - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - - self.similarity_func = similarity_func - - self.eps = eps - self.nr_samples = nr_samples - self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) - - # Asserts and warnings. 
- if not self.disable_warnings: - warn.warn_parameterisation( - metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline', threshold value 'eps' and number " - "of samples to iterate over 'nr_samples'" - ), - citation=( - "Nguyen, An-phi, and María Rodríguez Martínez. 'On quantitative aspects of " - "model interpretability.' arXiv preprint arXiv:2007.07584 (2020)" - ), - ) - - def __call__( - self, - model, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: Optional[np.ndarray] = None, - s_batch: Optional[np.ndarray] = None, - channel_first: Optional[bool] = None, - explain_func: Optional[Callable] = None, - explain_func_kwargs: Optional[Dict] = None, - model_predict_kwargs: Optional[Dict] = None, - softmax: Optional[bool] = True, - device: Optional[str] = None, - batch_size: int = 64, - custom_batch: Optional[Any] = None, - **kwargs, - ) -> List[float]: - """ - This implementation represents the main logic of the metric and makes the class object callable. - It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), - output labels (y_batch) and a torch or tensorflow model (model). - - Calls general_preprocess() with all relevant arguments, calls - () on each instance, and saves results to evaluation_scores. - Calls custom_postprocess() afterwards. Finally returns evaluation_scores. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - channel_first: boolean, optional - Indicates of the image dimensions are channel first, or channel last. - Inferred from the input shape if None. - explain_func: callable - Callable generating attributions. - explain_func_kwargs: dict, optional - Keyword arguments to be passed to explain_func on call. - model_predict_kwargs: dict, optional - Keyword arguments to be passed to the model's predict method. - softmax: boolean - Indicates whether to use softmax probabilities or logits in model prediction. - This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. - device: string - Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". - custom_batch: any - Any object that can be passed to the evaluation process. - Gives flexibility to the user to adapt for implementing their own metric. - kwargs: optional - Keyword arguments. - - Returns - ------- - evaluation_scores: list - a list of Any with the evaluation scores of the concerned batch. - - Examples: - -------- - # Minimal imports. - >> import quantus - >> from quantus import LeNet - >> import torch - - # Enable GPU. - >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). - >> model = LeNet() - >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) - - # Load MNIST datasets and make loaders. 
- >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) - >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) - - # Load a batch of inputs and outputs to use for XAI evaluation. - >> x_batch, y_batch = iter(test_loader).next() - >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() - - # Generate Saliency attributions of the test set batch of the test set. - >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) - >> a_batch_saliency = a_batch_saliency.cpu().numpy() - - # Initialise the metric and evaluate explanations by calling the metric instance. - >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) - """ - return super().__call__( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - custom_batch=custom_batch, - channel_first=channel_first, - explain_func=explain_func, - explain_func_kwargs=explain_func_kwargs, - softmax=softmax, - device=device, - model_predict_kwargs=model_predict_kwargs, - batch_size=batch_size, - **kwargs, - ) - - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Predict on input x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred) - inv_pred = inv_pred**2 - - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - atts = [None for _ in range(n_perturbations)] - vars = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - - y_pred_perturbs = [] - - for s_ix in range(self.nr_samples): - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - y_pred_perturbs.append(y_pred_perturb) - - vars[i_ix] = float( - np.mean((np.array(y_pred_perturbs) - np.array(y_pred)) ** 2) * inv_pred - ) - atts[i_ix] = float(sum(a[a_ix])) - - return self.similarity_func(a=atts, b=vars) - - def custom_preprocess( - self, - x_batch: np.ndarray, - **kwargs, - ) -> None: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - kwargs: - Unused. - - Returns - ------- - None - """ - # Asserts. 
- asserts.assert_features_in_step( - features_in_step=self.features_in_step, - input_shape=x_batch.shape[2:], - ) - - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ) -> List[float]: - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A model that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - kwargs: - Unused. - - Returns - ------- - scores_batch: - The evaluation results. - """ - - # Evaluate explanations. - return [ - self.evaluate_instance( - model=model, - x=x, - y=y, - a=a, - ) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index cdc531c4c..0179cb415 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -11,10 +11,7 @@ import numpy as np import math -from quantus.functions.perturb_func import ( - baseline_replacement_by_indices, - batch_baseline_replacement_by_indices, -) +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( DataType, @@ -33,7 +30,7 @@ @final -class BatchPixelFlipping(Metric[Union[float, List[float]]]): +class PixelFlipping(Metric[Union[float, List[float]]]): """ Implementation of Pixel-Flipping experiment by Bach et al., 2015. @@ -321,6 +318,10 @@ def evaluate_batch( scores_batch: The evaluation results. """ + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + # Flatten the attributions. batch_size = a_batch.shape[0] a_batch = a_batch.reshape(batch_size, -1) @@ -366,358 +367,3 @@ def evaluate_batch( return utils.calculate_auc(preds, batched=True).tolist() return preds.tolist() - - -@final -class PixelFlipping(Metric[Union[float, List[float]]]): - """ - Implementation of Pixel-Flipping experiment by Bach et al., 2015. - - The basic idea is to compute a decomposition of a digit for a digit class - and then flip pixels with highly positive, highly negative scores or pixels - with scores close to zero and then to evaluate the impact of these flips - onto the prediction scores (mean prediction is calculated). - - References: - 1) Sebastian Bach et al.: "On pixel-wise explanations for non-linear classifier - decisions by layer-wise relevance propagation." PloS one 10.7 (2015): e0130140. - - Attributes: - - _name: The name of the metric. - - _data_applicability: The data types that the metric implementation currently supports. - - _models: The model types that this metric can work with. - - score_direction: How to interpret the scores, whether higher/ lower values are considered better. - - evaluation_category: What property/ explanation quality that this metric measures. 
- """ - - name = "Pixel-Flipping" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} - model_applicability = {ModelType.TORCH, ModelType.TF} - score_direction = ScoreDirection.LOWER - evaluation_category = EvaluationCategory.FAITHFULNESS - - def __init__( - self, - features_in_step: int = 1, - abs: bool = False, - normalise: bool = True, - normalise_func: Optional[Callable[[np.ndarray], np.ndarray]] = None, - normalise_func_kwargs: Optional[Dict[str, Any]] = None, - perturb_func: Optional[Callable] = None, - perturb_baseline: str = "black", - perturb_func_kwargs: Optional[Dict[str, Any]] = None, - return_aggregate: bool = False, - aggregate_func: Optional[Callable] = None, - return_auc_per_sample: bool = False, - default_plot_func: Optional[Callable] = None, - disable_warnings: bool = False, - display_progressbar: bool = False, - **kwargs, - ): - """ - Parameters - ---------- - features_in_step: integer - The size of the step, default=1. - abs: boolean - Indicates whether absolute operation is applied on the attribution, default=False. - normalise: boolean - Indicates whether normalise operation is applied on the attribution, default=True. - normalise_func: callable - Attribution normalisation function applied in case normalise=True. - If normalise_func=None, the default value is used, default=normalise_by_max. - normalise_func_kwargs: dict - Keyword arguments to be passed to normalise_func on call, default={}. - perturb_func: callable - Input perturbation function. If None, the default value is used, - default=baseline_replacement_by_indices. - perturb_baseline: string - Indicates the type of baseline: "mean", "random", "uniform", "black" or "white", - default="black". - perturb_func_kwargs: dict - Keyword arguments to be passed to perturb_func, default={}. - return_aggregate: boolean - Indicates if an aggregated score should be computed over all instances. - aggregate_func: callable - Callable that aggregates the scores given an evaluation call. - return_auc_per_sample: boolean - Indicates if an AUC score should be computed over the curve and returned. - default_plot_func: callable - Callable that plots the metrics result. - disable_warnings: boolean - Indicates whether the warnings are printed, default=False. - display_progressbar: boolean - Indicates whether a tqdm-progress-bar is printed, default=False. - kwargs: optional - Keyword arguments. - """ - if default_plot_func is None: - default_plot_func = plotting.plot_pixel_flipping_experiment - - super().__init__( - abs=abs, - normalise=normalise, - normalise_func=normalise_func, - normalise_func_kwargs=normalise_func_kwargs, - return_aggregate=return_aggregate, - aggregate_func=aggregate_func, - default_plot_func=default_plot_func, - display_progressbar=display_progressbar, - disable_warnings=disable_warnings, - **kwargs, - ) - - if perturb_func is None: - perturb_func = baseline_replacement_by_indices - - # Save metric-specific attributes. - self.features_in_step = features_in_step - self.return_auc_per_sample = return_auc_per_sample - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) - - # Asserts and warnings. - if not self.disable_warnings: - warn.warn_parameterisation( - metric_name=self.__class__.__name__, - sensitive_params="baseline value 'perturb_baseline'", - citation=( - "Bach, Sebastian, et al. 'On pixel-wise explanations for non-linear classifier" - " decisions by layer - wise relevance propagation.' 
PloS one 10.7 (2015) " - "e0130140" - ), - ) - - def __call__( - self, - model, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: Optional[np.ndarray] = None, - s_batch: Optional[np.ndarray] = None, - channel_first: Optional[bool] = None, - explain_func: Optional[Callable] = None, - explain_func_kwargs: Optional[Dict] = None, - model_predict_kwargs: Optional[Dict] = None, - softmax: Optional[bool] = True, - device: Optional[str] = None, - batch_size: int = 64, - **kwargs, - ) -> List[float]: - """ - This implementation represents the main logic of the metric and makes the class object callable. - It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), - output labels (y_batch) and a torch or tensorflow model (model). - - Calls general_preprocess() with all relevant arguments, calls - () on each instance, and saves results to evaluation_scores. - Calls custom_postprocess() afterwards. Finally returns evaluation_scores. - - Parameters - ---------- - model: torch.nn.Module, tf.keras.Model - A torch or tensorflow model that is subject to explanation. - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - y_batch: np.ndarray - A np.ndarray which contains the output labels that are explained. - a_batch: np.ndarray, optional - A np.ndarray which contains pre-computed attributions i.e., explanations. - s_batch: np.ndarray, optional - A np.ndarray which contains segmentation masks that matches the input. - channel_first: boolean, optional - Indicates of the image dimensions are channel first, or channel last. - Inferred from the input shape if None. - explain_func: callable - Callable generating attributions. - explain_func_kwargs: dict, optional - Keyword arguments to be passed to explain_func on call. - model_predict_kwargs: dict, optional - Keyword arguments to be passed to the model's predict method. - softmax: boolean - Indicates whether to use softmax probabilities or logits in model prediction. - This is used for this __call__ only and won't be saved as attribute. If None, self.softmax is used. - device: string - Indicated the device on which a torch.Tensor is or will be allocated: "cpu" or "gpu". - kwargs: optional - Keyword arguments. - - Returns - ------- - evaluation_scores: list - a list of Any with the evaluation scores of the concerned batch. - - Examples: - -------- - # Minimal imports. - >> import quantus - >> from quantus import LeNet - >> import torch - - # Enable GPU. - >> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # Load a pre-trained LeNet classification model (architecture at quantus/helpers/models). - >> model = LeNet() - >> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model")) - - # Load MNIST datasets and make loaders. - >> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True) - >> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24) - - # Load a batch of inputs and outputs to use for XAI evaluation. - >> x_batch, y_batch = iter(test_loader).next() - >> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy() - - # Generate Saliency attributions of the test set batch of the test set. - >> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1) - >> a_batch_saliency = a_batch_saliency.cpu().numpy() - - # Initialise the metric and evaluate explanations by calling the metric instance. 
- >> metric = Metric(abs=True, normalise=False) - >> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency) - """ - return super().__call__( - model=model, - x_batch=x_batch, - y_batch=y_batch, - a_batch=a_batch, - s_batch=s_batch, - custom_batch=None, - channel_first=channel_first, - explain_func=explain_func, - explain_func_kwargs=explain_func_kwargs, - softmax=softmax, - device=device, - model_predict_kwargs=model_predict_kwargs, - batch_size=batch_size, - **kwargs, - ) - - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> Union[float, List[float]]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - list - The evaluation results. - """ - - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Prepare lists. - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - if self.return_auc_per_sample: - return float(utils.calculate_auc(preds)) - - return preds - - def custom_preprocess( - self, - x_batch: np.ndarray, - **kwargs, - ) -> None: - """ - Implementation of custom_preprocess_batch. - - Parameters - ---------- - x_batch: np.ndarray - A np.ndarray which contains the input data that are explained. - kwargs: - Unused. - - Returns - ------- - None - """ - # Asserts. - asserts.assert_features_in_step( - features_in_step=self.features_in_step, - input_shape=x_batch.shape[2:], - ) - - @property - def get_auc_score(self): - """Calculate the area under the curve (AUC) score for several test samples.""" - return np.mean( - [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] - ) - - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ) -> List[Union[float, List[float]]]: - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - kwargs: - Unused. - - Returns - ------- - scores_batch: - The evaluation results. 
- """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] From c56104e1352046186ffa66b6bd5d9a79b18ab887 Mon Sep 17 00:00:00 2001 From: davor Date: Tue, 30 Jul 2024 10:44:12 +0200 Subject: [PATCH 05/11] simplifying trapz calculation and removing TODO --- quantus/helpers/utils.py | 5 ++--- quantus/metrics/faithfulness/monotonicity.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index fa156b2b3..36cf550de 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -1034,9 +1034,8 @@ def calculate_auc(values: np.array, dx: int = 1, batched: bool = False): np.ndarray Definite integral of values. """ - if batched: - return np.trapz(values, dx=dx, axis=1) - return np.trapz(np.array(values), dx=dx) + axis = 1 if batched else -1 + return np.trapz(np.array(values), dx=dx, axis=axis) T = TypeVar("T") diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 9263a41e3..d7eaf4402 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -332,7 +332,7 @@ def evaluate_batch( return_shape=( batch_size, n_features, - ), # TODO. Double-check this over using = (1,). + ), batched=True, ) x_baseline = np.full((batch_size, n_features), baseline_value).reshape( From 17fb93e62fb0759badae1badc852db179afa1a5b Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Thu, 19 Sep 2024 16:03:01 +0200 Subject: [PATCH 06/11] faithfulness metrics ready to test (got a lot of questions) --- quantus/functions/loss_func.py | 10 +- quantus/functions/norm_func.py | 18 +- quantus/functions/perturb_func.py | 181 ++++++++++++-- quantus/functions/similarity_func.py | 48 ++-- quantus/helpers/model/pytorch_model.py | 35 +-- quantus/helpers/model/tf_model.py | 35 +-- quantus/helpers/utils.py | 180 +++++++------- quantus/metrics/faithfulness/infidelity.py | 174 ++++++------- quantus/metrics/faithfulness/irof.py | 150 +++++------ .../faithfulness/region_perturbation.py | 233 ++++++++---------- quantus/metrics/faithfulness/road.py | 99 +++----- quantus/metrics/faithfulness/selectivity.py | 209 +++++++--------- quantus/metrics/faithfulness/sensitivity_n.py | 161 +++++------- quantus/metrics/faithfulness/sufficiency.py | 53 +--- tests/conftest.py | 47 ++-- tests/metrics/test_faithfulness_metrics.py | 38 +-- tests/metrics/test_localisation_metrics.py | 12 +- 17 files changed, 784 insertions(+), 899 deletions(-) diff --git a/quantus/functions/loss_func.py b/quantus/functions/loss_func.py index 69181e1f3..1ed646fd9 100644 --- a/quantus/functions/loss_func.py +++ b/quantus/functions/loss_func.py @@ -9,7 +9,7 @@ import numpy as np -def mse(a: np.array, b: np.array, **kwargs) -> float: +def mse(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float: """ Calculate Mean Squared Error between two images (or explanations). @@ -19,6 +19,8 @@ def mse(a: np.array, b: np.array, **kwargs) -> float: Array to calculate MSE with. b: np.ndarray Array to calculate MSE with. + batched: bool + True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size and F is the number of features kwargs: optional Keyword arguments. 
normalise_mse: boolean @@ -31,10 +33,10 @@ def mse(a: np.array, b: np.array, **kwargs) -> float: """ normalise = kwargs.get("normalise_mse", False) - + axis = -1 if batched else 0 if normalise: # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2. - return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0) + return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=axis) # If no need to normalise, return (a-b)^2. - return np.average(((a - b) ** 2), axis=0) + return np.average(((a - b) ** 2), axis=axis) diff --git a/quantus/functions/norm_func.py b/quantus/functions/norm_func.py index 3d47bbdcb..94ef0b67a 100644 --- a/quantus/functions/norm_func.py +++ b/quantus/functions/norm_func.py @@ -16,15 +16,15 @@ def fro_norm(a: np.array) -> float: Parameters ---------- a: np.ndarray - The array to calculate the Frobenius on. + The array to calculate the Frobenius on. If 2D, the array is assumed to be batched. Returns ------- float The norm. """ - assert a.ndim == 1, "Check that 'fro_norm' receives a 1D array." - return np.linalg.norm(a) + assert a.ndim == 1 or a.ndim == 2, "Check that 'fro_norm' receives a 1D or 2D array." + return np.linalg.norm(a, axis=-1) def l2_norm(a: np.array) -> float: @@ -34,15 +34,15 @@ def l2_norm(a: np.array) -> float: Parameters ---------- a: np.ndarray - The array to calculate the L2 on + The array to calculate the L2 on. If 2D, the array is assumed to be batched. Returns ------- float The norm. """ - assert a.ndim == 1, "Check that 'l2_norm' receives a 1D array." - return np.linalg.norm(a) + assert a.ndim == 1 or a.ndim == 2, "Check that 'l2_norm' receives a 1D array." + return np.linalg.norm(a, axis=-1) def linf_norm(a: np.array) -> float: @@ -52,12 +52,12 @@ def linf_norm(a: np.array) -> float: Parameters ---------- a: np.ndarray - The array to calculate the L-inf on. + The array to calculate the L-inf on. If 2D, the array is assumed to be batched. Returns ------- float The norm. """ - assert a.ndim == 1, "Check that 'linf_norm' receives a 1D array." - return np.linalg.norm(a, ord=np.inf) + assert a.ndim == 1 or a.ndim == 2, "Check that 'linf_norm' receives a 1D or 2D array." + return np.linalg.norm(a, ord=np.inf, axis=-1) diff --git a/quantus/functions/perturb_func.py b/quantus/functions/perturb_func.py index f653feb25..a586501c9 100644 --- a/quantus/functions/perturb_func.py +++ b/quantus/functions/perturb_func.py @@ -53,9 +53,7 @@ def perturb_batch( None, array """ if indices is not None: - assert arr.shape[0] == len( - indices - ), "arr and indices need same number of batches" + assert arr.shape[0] == len(indices), "arr and indices need same number of batches" if not inplace: arr = arr.copy() @@ -108,9 +106,7 @@ def baseline_replacement_by_indices( arr_perturbed = copy.copy(arr) # Get the baseline value. - baseline_value = get_baseline_value( - value=perturb_baseline, arr=arr, return_shape=tuple(baseline_shape), **kwargs - ) + baseline_value = get_baseline_value(value=perturb_baseline, arr=arr, return_shape=tuple(baseline_shape), **kwargs) # Perturb the array. arr_perturbed[indices] = np.expand_dims(baseline_value, axis=tuple(indexed_axes)) @@ -170,6 +166,51 @@ def batch_baseline_replacement_by_indices( return arr_perturbed +def baseline_replacement_by_mask( + arr: np.array, + mask: np.array, + perturb_baseline: Union[float, int, str, np.array], + **kwargs, +) -> np.array: + """ + Replace indices in an array by a given baseline. + + Parameters + ---------- + arr: np.ndarray + Array to be perturbed. Shape is arbitrary. 
+ mask: np.ndarray + Boolean mask of the array to perturb. Shape must be the same as the arr. + perturb_baseline: float, int, str, np.ndarray + The baseline values to replace arr at indices with. + kwargs: optional + Keyword arguments. + + Returns + ------- + arr_perturbed: np.ndarray + The array which some of its indices have been perturbed. + """ + + # Assert dimensions + assert arr.shape == mask.shape, "The shape of arr must be the same as the mask shape" + + arr_perturbed = copy.copy(arr) + + # Get the baseline value. + baseline_value = get_baseline_value( + value=perturb_baseline, + arr=arr, + return_shape=tuple(arr.shape), + **kwargs, + ) + + # Perturb the array. + arr_perturbed = np.where(mask, baseline_value, arr_perturbed) + + return arr_perturbed + + def baseline_replacement_by_shift( arr: np.array, indices: Tuple[slice, ...], # Alt. Union[int, Sequence[int], Tuple[np.array]], @@ -206,9 +247,7 @@ def baseline_replacement_by_shift( arr_perturbed = copy.copy(arr) # Get the baseline value. - baseline_value = get_baseline_value( - value=input_shift, arr=arr, return_shape=tuple(baseline_shape), **kwargs - ) + baseline_value = get_baseline_value(value=input_shift, arr=arr, return_shape=tuple(baseline_shape), **kwargs) # Shift the input. arr_shifted = copy.copy(arr_perturbed) @@ -316,6 +355,56 @@ def gaussian_noise( return arr_perturbed +def batch_gaussian_noise( + arr: np.array, + indices: Tuple[slice, ...], # Alt. Union[int, Sequence[int], Tuple[np.array]], + perturb_mean: float = 0.0, + perturb_std: float = 0.01, + **kwargs, +) -> np.array: + """ + Add gaussian noise to the input at indices. + + Parameters + ---------- + arr: np.ndarray + Array to be perturbed. + indices: int, sequence, tuple + Array-like, with a subset shape of arr. + perturb_mean (float): + The mean for gaussian noise. + perturb_std (float): + The standard deviation for gaussian noise. + kwargs: optional + Keyword arguments. + + Returns + ------- + arr_perturbed: np.ndarray + The array which some of its indices have been perturbed. + """ + # Assert dimensions + assert ( + len(arr.shape) == 2 + ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features" + assert ( + len(indices.shape) == 2 + ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb" + + batch_size = arr.shape[0] + arr_perturbed = copy.copy(arr) + + # Sample the noise. + noise = np.random.normal(loc=perturb_mean, scale=perturb_std, size=arr.shape) + + # Perturb the array. + arr_perturbed[np.arange(batch_size)[:, None], indices] = (arr_perturbed + noise)[ + np.arange(batch_size)[:, None], indices + ] + + return arr_perturbed + + def uniform_noise( arr: np.array, indices: Tuple[slice, ...], # Alt. Union[int, Sequence[int], Tuple[np.array]], @@ -355,9 +444,10 @@ def uniform_noise( if upper_bound is None: noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=arr.shape) else: - assert upper_bound > lower_bound, ( - "Parameter 'upper_bound' needs to be larger than 'lower_bound', " - "but {} <= {}".format(upper_bound, lower_bound) + assert ( + upper_bound > lower_bound + ), "Parameter 'upper_bound' needs to be larger than 'lower_bound', " "but {} <= {}".format( + upper_bound, lower_bound ) noise = np.random.uniform(low=lower_bound, high=upper_bound, size=arr.shape) @@ -367,6 +457,66 @@ def uniform_noise( return arr_perturbed +def batch_uniform_noise( + arr: np.array, + indices: Tuple[slice, ...], # Alt. 
Union[int, Sequence[int], Tuple[np.array]], + lower_bound: float = 0.02, + upper_bound: Union[None, float] = None, + **kwargs, +) -> np.array: + """ + Add noise to the input at indices as sampled uniformly random from [-lower_bound, lower_bound]. + if upper_bound is None, and [lower_bound, upper_bound] otherwise. + + Parameters + ---------- + arr: np.ndarray + Array to be perturbed. + indices: int, sequence, tuple + Array-like, with a subset shape of arr. + lower_bound: float + The lower bound for uniform sampling. + upper_bound: float, optional + The upper bound for uniform sampling. + kwargs: optional + Keyword arguments. + + Returns + ------- + arr_perturbed: np.ndarray + The array which some of its indices have been perturbed. + """ + + # Assert dimensions + assert ( + len(arr.shape) == 2 + ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features" + assert ( + len(indices.shape) == 2 + ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb" + + batch_size = arr.shape[0] + arr_perturbed = copy.copy(arr) + + # Sample the noise. + if upper_bound is None: + noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=arr.shape) + else: + assert ( + upper_bound > lower_bound + ), "Parameter 'upper_bound' needs to be larger than 'lower_bound', " "but {} <= {}".format( + upper_bound, lower_bound + ) + noise = np.random.uniform(low=lower_bound, high=upper_bound, size=arr.shape) + + # Perturb the array. + arr_perturbed[np.arange(batch_size)[:, None], indices] = (arr_perturbed + noise)[ + np.arange(batch_size)[:, None], indices + ] + + return arr_perturbed + + def rotation(arr: np.array, perturb_angle: float = 10, **kwargs) -> np.array: """ Rotate array by some given angle, assumes image type data and channel first layout. @@ -388,8 +538,7 @@ def rotation(arr: np.array, perturb_angle: float = 10, **kwargs) -> np.array: if arr.ndim != 3: raise ValueError( - "perturb func 'rotation' requires image-type data." - "Check that this perturb_func receives a 3D array." + "perturb func 'rotation' requires image-type data." "Check that this perturb_func receives a 3D array." ) matrix = cv2.getRotationMatrix2D( @@ -575,9 +724,7 @@ def noisy_linear_imputation( a[out_off_coord_ids, variable_ids] = weight # Reduce weight for invalid coordinates. - sum_neighbors[np.argwhere(valid == 0).flatten()] = ( - sum_neighbors[np.argwhere(valid == 0).flatten()] - weight - ) + sum_neighbors[np.argwhere(valid == 0).flatten()] = sum_neighbors[np.argwhere(valid == 0).flatten()] - weight a[np.arange(len(indices)), np.arange(len(indices))] = -sum_neighbors diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 6c2b83eff..da7f4bc18 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -12,11 +12,10 @@ import numpy as np import scipy import skimage +import sys -def correlation_spearman( - a: np.array, b: np.array, batched: bool = False, **kwargs -) -> Union[float, np.array]: +def correlation_spearman(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, np.array]: """ Calculate Spearman rank of two images (or explanations). 
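
A minimal usage sketch of the batched noise perturbations added above (illustrative only, not part of the patch itself): the toy batch of 4 instances with 10 features and the top-3 index selection are assumptions made for the example.

# Sketch: exercising batch_gaussian_noise / batch_uniform_noise on a toy batch.
# Shapes and the perturbed-index choice below are illustrative assumptions.
import numpy as np

from quantus.functions.perturb_func import batch_gaussian_noise, batch_uniform_noise

arr = np.random.rand(4, 10)                # B x F batch of flattened inputs
indices = np.argsort(-arr, axis=1)[:, :3]  # B x I indices to perturb per instance

perturbed_gauss = batch_gaussian_noise(arr=arr, indices=indices, perturb_mean=0.0, perturb_std=0.01)
perturbed_uniform = batch_uniform_noise(arr=arr, indices=indices, lower_bound=0.02)

assert perturbed_gauss.shape == arr.shape    # only the selected indices are changed
assert perturbed_uniform.shape == arr.shape
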
@@ -45,9 +44,7 @@ def correlation_spearman( return scipy.stats.spearmanr(a, b)[0] -def correlation_pearson( - a: np.array, b: np.array, batched: bool = False, **kwargs -) -> Union[float, np.array]: +def correlation_pearson(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, np.array]: """ Calculate Pearson correlation of two images (or explanations). @@ -69,13 +66,14 @@ def correlation_pearson( """ if batched: assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" - return scipy.stats.pearsonr(a, b, axis=1)[0] + # No axis parameter in older versions + if sys.version_info >= (3, 10): + return scipy.stats.pearsonr(a, b, axis=1)[0] + return np.array([scipy.stats.pearsonr(aa, bb)[0] for aa, bb in zip(a, b)]) return scipy.stats.pearsonr(a, b)[0] -def correlation_kendall_tau( - a: np.array, b: np.array, batched: bool = False, **kwargs -) -> Union[float, np.array]: +def correlation_kendall_tau(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, np.array]: """ Calculate Kendall Tau correlation of two images (or explanations). @@ -120,7 +118,7 @@ def distance_euclidean(a: np.array, b: np.array, **kwargs) -> float: float The similarity score. """ - return scipy.spatial.distance.euclidean(u=a, v=b) + return ((a - b) ** 2).sum(axis=-1) ** 0.5 def distance_manhattan(a: np.array, b: np.array, **kwargs) -> float: @@ -141,7 +139,7 @@ def distance_manhattan(a: np.array, b: np.array, **kwargs) -> float: float The similarity score. """ - return scipy.spatial.distance.cityblock(u=a, v=b) + return abs(a - b).sum(-1) def distance_chebyshev(a: np.array, b: np.array, **kwargs) -> float: @@ -201,9 +199,9 @@ def lipschitz_constant( d2 = kwargs.get("norm_denominator", distance_euclidean) if np.shape(a) == (): - return float(abs(a - b) / (d2(c, d) + eps)) + return abs(a - b) / (d2(c, d) + eps) else: - return float(d1(a, b) / (d2(a=c, b=d) + eps)) + return d1(a, b) / (d2(a=c, b=d) + eps) def abs_difference(a: np.array, b: np.array, **kwargs) -> float: @@ -269,7 +267,7 @@ def cosine(a: np.array, b: np.array, **kwargs) -> float: return scipy.spatial.distance.cosine(u=a, v=b) -def ssim(a: np.array, b: np.array, **kwargs) -> float: +def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float: """ Calculate Structural Similarity Index Measure of two images (or explanations). @@ -279,6 +277,8 @@ def ssim(a: np.array, b: np.array, **kwargs) -> float: The first array to use for similarity scoring. b: np.ndarray The second array to use for similarity scoring. + batched: bool + Whether the arrays are batched. kwargs: optional Keyword arguments. @@ -287,13 +287,17 @@ def ssim(a: np.array, b: np.array, **kwargs) -> float: float The similarity score. 
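+        If batched is True, a list with one score per sample is returned.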
""" - max_point, min_point = np.max(np.concatenate([a, b])), np.min( - np.concatenate([a, b]) - ) - data_range = float(np.abs(max_point - min_point)) - return skimage.metrics.structural_similarity( - im1=a, im2=b, win_size=kwargs.get("win_size", None), data_range=data_range - ) + + def inner(aa: np.array, bb: np.array) -> float: + max_point, min_point = np.max(np.concatenate([aa, bb])), np.min(np.concatenate([aa, bb])) + data_range = float(np.abs(max_point - min_point)) + return skimage.metrics.structural_similarity( + im1=aa, im2=bb, win_size=kwargs.get("win_size", None), data_range=data_range + ) + + if batched: + return [inner(aa, bb) for aa, bb in zip(a, b)] + return inner(a, b) def difference(a: np.array, b: np.array, **kwargs) -> np.array: diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index 257cf0352..5c61a8335 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -273,6 +273,11 @@ def shape_input( # Expand first dimension if this is just a single instance. if not batched: x = x.reshape(1, *shape) + shape = (1, *shape) + + # If shape not the same, reshape the input + if shape != x.shape: + x = x.reshape(*shape) # Set channel order according to expected input of model. if self.channel_first: @@ -314,11 +319,7 @@ def get_random_layer_generator( original_parameters = self.state_dict() random_layer_model = deepcopy(self.model) - modules = [ - layer - for layer in random_layer_model.named_modules() - if (hasattr(layer[1], "reset_parameters")) - ] + modules = [layer for layer in random_layer_model.named_modules() if (hasattr(layer[1], "reset_parameters"))] if order == "top_down": modules = modules[::-1] @@ -408,13 +409,9 @@ def add_mean_shift_to_first_layer( for i in range(module[1].out_channels): if self.channel_first: - module[1].bias[i] = torch.nn.Parameter( - 2 * module[1].bias[i] - torch.unique(fw[i])[0] - ) + module[1].bias[i] = torch.nn.Parameter(2 * module[1].bias[i] - torch.unique(fw[i])[0]) else: - module[1].bias[i] = torch.nn.Parameter( - 2 * module[1].bias[i] - torch.unique(fw[..., i])[0] - ) + module[1].bias[i] = torch.nn.Parameter(2 * module[1].bias[i] - torch.unique(fw[..., i])[0]) return new_model @@ -457,9 +454,7 @@ def get_hidden_representations( # E.g., user can provide index -1, in order to get only representations of the last layer. # E.g., for 7 layers in total, this would correspond to positive index 6. - positive_layer_indices = [ - i if i >= 0 else num_layers + i for i in layer_indices - ] + positive_layer_indices = [i if i >= 0 else num_layers + i for i in layer_indices] if layer_names is None: layer_names = [] @@ -472,9 +467,7 @@ def is_layer_of_interest(layer_index: int, layer_name: str): # skip modules defined by subclassing API. 
hidden_layers = list( # type: ignore filter( - lambda layer: not isinstance( - layer[1], (self.model.__class__, torch.nn.Sequential) - ), + lambda layer: not isinstance(layer[1], (self.model.__class__, torch.nn.Sequential)), all_layers, ) ) @@ -509,13 +502,7 @@ def hook(module, module_in, module_out): @property def random_layer_generator_length(self) -> int: - return len( - [ - i - for i in self.model.named_modules() - if (hasattr(i[1], "reset_parameters")) - ] - ) + return len([i for i in self.model.named_modules() if (hasattr(i[1], "reset_parameters"))]) def safe_isinstance(obj: Any, class_path_str: Union[Iterable[str], str]) -> bool: diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index a7a5f388b..66162ca33 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -82,9 +82,7 @@ def _get_predict_kwargs(self, **kwargs: Dict[str, ...]) -> Dict[str, ...]: Filter out those, which are supported by Keras API. """ all_kwargs = {**self.model_predict_kwargs, **kwargs} - predict_kwargs = { - k: all_kwargs[k] for k in all_kwargs if k in self._available_predict_kwargs - } + predict_kwargs = {k: all_kwargs[k] for k in all_kwargs if k in self._available_predict_kwargs} return predict_kwargs @property @@ -214,6 +212,11 @@ def shape_input( # Expand first dimension if this is just a single instance. if not batched: x = x.reshape(1, *shape) + shape = (1, *shape) + + # If shape not the same, reshape the input + if shape != x.shape: + x = x.reshape(*shape) # Set channel order according to expected input of model. if self.channel_first: @@ -259,11 +262,7 @@ def get_random_layer_generator( original_parameters = self.state_dict() random_layer_model = clone_model(self.model) - layers = [ - _layer - for _layer in random_layer_model.layers - if len(_layer.get_weights()) > 0 - ] + layers = [_layer for _layer in random_layer_model.layers if len(_layer.get_weights()) > 0] if order == "top_down": layers = layers[::-1] @@ -277,9 +276,7 @@ def get_random_layer_generator( yield layer.name, random_layer_model @cachedmethod(operator.attrgetter("cache")) - def _build_hidden_representation_model( - self, layer_names: Tuple, layer_indices: Tuple - ) -> Model: + def _build_hidden_representation_model(self, layer_names: Tuple, layer_indices: Tuple) -> Model: """ Build a keras model, which outputs the internal representation of layers, specified in layer_names or layer_indices, default all. @@ -337,9 +334,7 @@ def add_mean_shift_to_first_layer( new_model.set_weights(original_parameters) module = new_model.layers[0] - tmp_model = Model( - inputs=[new_model.input], outputs=[new_model.layers[0].output] - ) + tmp_model = Model(inputs=[new_model.input], outputs=[new_model.layers[0].output]) delta = np.zeros(shape=shape) delta.fill(input_shift) @@ -395,9 +390,7 @@ def get_hidden_representations( # E.g., user can provide index -1, in order to get only representations of the last layer. # E.g., for 7 layers in total, this would correspond to positive index 6. 
- positive_layer_indices = [ - i if i >= 0 else num_layers + i for i in layer_indices - ] + positive_layer_indices = [i if i >= 0 else num_layers + i for i in layer_indices] if layer_names is None: layer_names = [] @@ -406,9 +399,7 @@ def get_hidden_representations( tuple(layer_names), tuple(positive_layer_indices) ) predict_kwargs = self._get_predict_kwargs(**kwargs) - internal_representation = hidden_representation_model.predict( - x, **predict_kwargs - ) + internal_representation = hidden_representation_model.predict(x, **predict_kwargs) input_batch_size = x.shape[0] # If we requested outputs only of 1 layer, keras will already return np.ndarray. @@ -416,9 +407,7 @@ def get_hidden_representations( if isinstance(internal_representation, np.ndarray): return internal_representation.reshape((input_batch_size, -1)) - internal_representation = [ - i.reshape((input_batch_size, -1)) for i in internal_representation - ] + internal_representation = [i.reshape((input_batch_size, -1)) for i in internal_representation] return np.hstack(internal_representation) @property diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 36cf550de..8880f239a 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -13,6 +13,8 @@ import numpy as np from skimage.segmentation import slic, felzenszwalb +import skimage +import math from quantus.helpers import asserts from quantus.helpers.model.model_interface import ModelInterface @@ -44,13 +46,10 @@ def get_superpixel_segments(img: np.ndarray, segmentation_method: str) -> np.nda if img.ndim != 3: raise ValueError( - "Make sure that x is 3 dimensional e.g., (3, 224, 224) to calculate super-pixels." - f" shape: {img.shape}" + "Make sure that x is 3 dimensional e.g., (3, 224, 224) to calculate super-pixels." f" shape: {img.shape}" ) if segmentation_method not in ["slic", "felzenszwalb"]: - raise ValueError( - "'segmentation_method' must be either 'slic' or 'felzenszwalb'." - ) + raise ValueError("'segmentation_method' must be either 'slic' or 'felzenszwalb'.") if segmentation_method == "slic": return slic(img, start_label=0) @@ -120,9 +119,7 @@ def get_baseline_value( "which will replicate the results of 'random').\n" ) if value.lower() not in fill_dict: - raise ValueError( - f"Ensure that 'value'(string) is in {list(fill_dict.keys())}" - ) + raise ValueError(f"Ensure that 'value'(string) is in {list(fill_dict.keys())}") fill_value = fill_dict[value.lower()] # Expand the second dimension if batched to enable broadcasting if batched: @@ -133,9 +130,7 @@ def get_baseline_value( raise ValueError("Specify 'value' as a np.array, string, integer or float.") -def get_baseline_dict( - arr: np.ndarray, patch: Optional[np.ndarray] = None, batched: bool = False, **kwargs -) -> dict: +def get_baseline_dict(arr: np.ndarray, patch: Optional[np.ndarray] = None, batched: bool = False, **kwargs) -> dict: """ Make a dictionary of baseline approaches depending on the input x (or patch of input). @@ -305,9 +300,7 @@ def infer_channel_first(x: np.array) -> bool: raise ValueError(err_msg) else: - raise ValueError( - "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." 
- ) + raise ValueError("Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels).") def make_channel_first(x: np.array, channel_first: bool = False): @@ -337,9 +330,7 @@ def make_channel_first(x: np.array, channel_first: bool = False): return np.moveaxis(x, -1, -3) else: - raise ValueError( - "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." - ) + raise ValueError("Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels).") def make_channel_last(x: np.array, channel_first: bool = True): @@ -368,9 +359,7 @@ def make_channel_last(x: np.array, channel_first: bool = True): elif len(np.shape(x)) == 4: return np.moveaxis(x, -3, -1) else: - raise ValueError( - "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." - ) + raise ValueError("Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels).") def get_wrapped_model( @@ -448,9 +437,7 @@ def blur_at_indices( A version of arr that is blurred at indices. """ - assert kernel.ndim == len( - indexed_axes - ), "kernel should have as many dimensions as indexed_axes has elements." + assert kernel.ndim == len(indexed_axes), "kernel should have as many dimensions as indexed_axes has elements." # Pad array pad_width = [(0, 0) for _ in indexed_axes] @@ -475,9 +462,7 @@ def blur_at_indices( array_indices = np.array(array_indices) # Expand kernel dimensions - expanded_kernel = np.expand_dims( - kernel, tuple([i for i in range(arr.ndim) if i not in indexed_axes]) - ) + expanded_kernel = np.expand_dims(kernel, tuple([i for i in range(arr.ndim) if i not in indexed_axes])) # Iterate over indices, applying expanded kernel x_blur = copy.copy(x) @@ -489,9 +474,7 @@ def blur_at_indices( for ax, idx_ax in enumerate(expanded_idx): s = kernel.shape[ax] idx_ax = np.squeeze(idx_ax) - expanded_idx[ax] = slice( - idx_ax - (s // 2), idx_ax + s // 2 + 1 - (s % 2 == 0) - ) + expanded_idx[ax] = slice(idx_ax - (s // 2), idx_ax + s // 2 + 1 - (s % 2 == 0)) for dim in range(array_indices.ndim - 2): idx = [np.expand_dims(index, axis=index.ndim) for index in idx] @@ -512,9 +495,7 @@ def blur_at_indices( return _unpad_array(x_blur, pad_width, padded_axes=indexed_axes) -def create_patch_slice( - patch_size: Union[int, Sequence[int]], coords: Sequence[int] -) -> Tuple[slice, ...]: +def create_patch_slice(patch_size: Union[int, Sequence[int]], coords: Sequence[int]) -> Tuple[slice, ...]: """ Create a patch slice from patch size and coordinates. 
@@ -545,24 +526,48 @@ def create_patch_slice( elif patch_size.ndim != 1: raise ValueError("patch_size has to be either a scalar or a 1D-sequence") elif len(patch_size) != len(coords): - raise ValueError( - "patch_size sequence length does not match coords length" - f" (len(patch_size) != len(coords))" - ) + raise ValueError("patch_size sequence length does not match coords length" f" (len(patch_size) != len(coords))") # make sure that each element in tuple is integer patch_size = tuple(int(patch_size_dim) for patch_size_dim in patch_size) - patch_slice = [ - slice(coord, coord + patch_size_dim) - for coord, patch_size_dim in zip(coords, patch_size) - ] + patch_slice = [slice(coord, coord + patch_size_dim) for coord, patch_size_dim in zip(coords, patch_size)] return tuple(patch_slice) -def get_nr_patches( - patch_size: Union[int, Sequence[int]], shape: Tuple[int, ...], overlap: bool = False -) -> int: +def get_block_indices(x_batch: np.array, patch_size: int) -> np.array: + """ + Get blocks for a batch of images using a certain patch_size. + + Parameters + ---------- + x_batch: np.array + Batch of images, shape B x C x H x W + patch_size: int + Size of the patch. + + Yields + ------- + np.array + Yields blocks of image of a certain patch_size. + """ + batch_size = x_batch.shape[0] + x_indices = np.stack( + [np.arange(x.size) for x in x_batch], + axis=0, + ).reshape(*x_batch.shape) + blocks = skimage.util.view_as_blocks(x_indices, (*x_batch.shape[:2], patch_size, patch_size)) + blocks_h, blocks_w = blocks.shape[2:4] + block_indices = np.stack( + np.unravel_index(list(range(blocks_h * blocks_w)), shape=(blocks_h, blocks_w)), + axis=1, + ) + + for block_index in block_indices: + yield blocks[0, 0, block_index[0], block_index[1]].reshape(batch_size, -1) + + +def get_nr_patches(patch_size: Union[int, Sequence[int]], shape: Tuple[int, ...], overlap: bool = False) -> int: """ Get number of patches for given shape. @@ -590,10 +595,7 @@ def get_nr_patches( elif patch_size.ndim != 1: raise ValueError("patch_size has to be either a scalar or a 1D-sequence") elif len(patch_size) != len(shape): - raise ValueError( - "patch_size sequence length does not match shape length" - f" (len(patch_size) != len(shape))" - ) + raise ValueError("patch_size sequence length does not match shape length" f" (len(patch_size) != len(shape))") patch_size = tuple(patch_size) return np.prod(shape) // np.prod(patch_size) @@ -625,14 +627,10 @@ def _pad_array( Padded array. 
""" - assert ( - len(padded_axes) <= arr.ndim - ), "Cannot pad more axes than array has dimensions" + assert len(padded_axes) <= arr.ndim, "Cannot pad more axes than array has dimensions" if isinstance(pad_width, Sequence): - assert len(pad_width) == len( - padded_axes - ), "pad_width and padded_axes have different lengths" + assert len(pad_width) == len(padded_axes), "pad_width and padded_axes have different lengths" for p in pad_width: if isinstance(p, tuple): assert len(p) == 2, "Elements in pad_width need to have length 2" @@ -653,7 +651,6 @@ def _pad_array( ) else: pad_width_list.append(pad_width[[p for p in padded_axes].index(ax)]) - arr_pad = np.pad(arr, pad_width_list, mode=mode) return arr_pad @@ -683,14 +680,10 @@ def _unpad_array( """ - assert ( - len(padded_axes) <= arr.ndim - ), "Cannot unpad more axes than array has dimensions" + assert len(padded_axes) <= arr.ndim, "Cannot unpad more axes than array has dimensions" if isinstance(pad_width, Sequence): - assert len(pad_width) == len( - padded_axes - ), "pad_width and padded_axes have different lengths" + assert len(pad_width) == len(padded_axes), "pad_width and padded_axes have different lengths" for p in pad_width: if isinstance(p, tuple): assert len(p) == 2, "Elements in pad_width need to have length 2" @@ -721,6 +714,30 @@ def _unpad_array( return unpadded_arr +def get_padding_size(dim: int, patch_size: int) -> Tuple[int]: + """ + Calculate the padding size (optionally) needed for a patch_size. + + Parameters + ---------- + dim: int + Input dimension to pad (for example width or height of the image). + patch_size: int + Size of the patch for this dimension of input. + + Returns + ------- + Tuple[int] + A tuple of values passed to the utils._pad_array method for a particular dimension. + """ + modulo = dim % patch_size + if modulo == 0: + return (0, 0) + total_padding = patch_size - modulo + half_padding = total_padding / 2 + return (math.floor(half_padding), math.ceil(half_padding)) + + def expand_attribution_channel(a_batch: np.ndarray, x_batch: np.ndarray): """ Expand additional channel dimension(s) for attributions if needed. @@ -742,9 +759,7 @@ def expand_attribution_channel(a_batch: np.ndarray, x_batch: np.ndarray): f"a_batch and x_batch must have same number of batches ({a_batch.shape[0]} != {x_batch.shape[0]})" ) if a_batch.ndim > x_batch.ndim: - raise ValueError( - f"a must not have greater ndim than x ({a_batch.ndim} > {x_batch.ndim})" - ) + raise ValueError(f"a must not have greater ndim than x ({a_batch.ndim} > {x_batch.ndim})") if a_batch.ndim == x_batch.ndim: return a_batch @@ -783,9 +798,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence if a_batch.ndim > x_batch.ndim: raise ValueError( - "Attributions need to have <= dimensions than inputs, but {} > {}".format( - a_batch.ndim, x_batch.ndim - ) + "Attributions need to have <= dimensions than inputs, but {} > {}".format(a_batch.ndim, x_batch.ndim) ) # TODO: We currently assume here that the batch axis is not carried into the perturbation functions. 
@@ -800,17 +813,14 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence return np.array([]) x_subshapes = [ - [x_shape[i] for i in range(start, start + len(a_shape))] - for start in range(0, len(x_shape) - len(a_shape) + 1) + [x_shape[i] for i in range(start, start + len(a_shape))] for start in range(0, len(x_shape) - len(a_shape) + 1) ] if x_subshapes.count(a_shape) < 1: # Check that attribution dimensions are (consecutive) subdimensions of inputs raise ValueError( "Attribution dimensions are not (consecutive) subdimensions of inputs: " - "inputs were of shape {} and attributions of shape {}".format( - x_batch.shape, a_batch.shape - ) + "inputs were of shape {} and attributions of shape {}".format(x_batch.shape, a_batch.shape) ) elif x_subshapes.count(a_shape) > 1: @@ -833,17 +843,13 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence raise ValueError( "Attribution axes could not be inferred for inputs of " - "shape {} and attributions of shape {}".format( - x_batch.shape, a_batch.shape - ) + "shape {} and attributions of shape {}".format(x_batch.shape, a_batch.shape) ) raise ValueError( "Attribution dimensions are not unique subdimensions of inputs: " "inputs were of shape {} and attributions of shape {}." - "Please expand attribution dimensions for a unique solution".format( - x_batch.shape, a_batch.shape - ) + "Please expand attribution dimensions for a unique solution".format(x_batch.shape, a_batch.shape) ) else: # Infer attribution axes. @@ -917,20 +923,10 @@ def expand_indices( # Check if unraveling is needed. if np.all([isinstance(idx, int) for idx in expanded_indices]): - expanded_indices = np.unravel_index( - expanded_indices, tuple([arr.shape[i] for i in indexed_axes]) - ) - elif not np.all( - [ - isinstance(idx, np.ndarray) and idx.ndim == len(expanded_indices) - for idx in expanded_indices - ] - ): + expanded_indices = np.unravel_index(expanded_indices, tuple([arr.shape[i] for i in indexed_axes])) + elif not np.all([isinstance(idx, np.ndarray) and idx.ndim == len(expanded_indices) for idx in expanded_indices]): # Meshgrid sliced axes to account for correct slicing. Correct switched first two axes by meshgrid - expanded_indices = [ - np.swapaxes(idx, 0, 1) if idx.ndim > 1 else idx - for idx in np.meshgrid(*expanded_indices) - ] + expanded_indices = [np.swapaxes(idx, 0, 1) if idx.ndim > 1 else idx for idx in np.meshgrid(*expanded_indices)] # Handle case of 1D indices. if np.all([isinstance(idx, int) for idx in expanded_indices]): @@ -985,9 +981,7 @@ def get_leftover_shape(arr: np.array, axes: Sequence[int]) -> Tuple: return leftover_shape -def offset_coordinates( - indices: Union[list, Sequence[int], Tuple[Any]], offset: tuple, img_shape: tuple -): +def offset_coordinates(indices: Union[list, Sequence[int], Tuple[Any]], offset: tuple, img_shape: tuple): """ Checks if offset coordinates are within the image frame. Adapted from: https://github.com/tleemann/road_evaluation. 
diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 9a3b09703..b5c079973 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -11,7 +11,9 @@ import numpy as np from quantus.functions.loss_func import mse -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, +) from quantus.helpers import utils, warn from quantus.helpers.enums import ( DataType, @@ -148,16 +150,14 @@ def __init__( self.loss_func = loss_func if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices if perturb_patch_sizes is None: perturb_patch_sizes = [4] self.perturb_patch_sizes = perturb_patch_sizes self.n_perturb_samples = n_perturb_samples self.nr_channels = None - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -283,100 +283,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInterface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - results = [] - - for _ in range(self.n_perturb_samples): - sub_results = [] - - for patch_size in self.perturb_patch_sizes: - pred_deltas = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - a_sums = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - x_perturbed = x.copy() - pad_width = patch_size - 1 - - for i_x, top_left_x in enumerate(range(0, x.shape[1], patch_size)): - for i_y, top_left_y in enumerate(range(0, x.shape[2], patch_size)): - # Perturb input patch-wise. - x_perturbed_pad = utils._pad_array( - x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes - ) - patch_slice = utils.create_patch_slice( - patch_size=patch_size, - coords=[top_left_x, top_left_y], - ) - - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - ) - - # Remove padding. - x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - - # Predict on perturbed input x_perturbed. 
- x_input = model.shape_input( - x_perturbed, x.shape, channel_first=True - ) - warn.warn_perturbation_caused_no_change( - x=x, x_perturbed=x_input - ) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - x_diff = x - x_perturbed - a_diff = np.dot( - np.repeat(a, repeats=self.nr_channels, axis=0), x_diff - ) - - pred_deltas[i_x][i_y] = y_pred - y_pred_perturb - a_sums[i_x][i_y] = np.sum(a_diff) - - assert callable(self.loss_func) - sub_results.append( - self.loss_func(a=pred_deltas.flatten(), b=a_sums.flatten()) - ) - - results.append(np.mean(sub_results)) - - return np.mean(results) - def custom_preprocess( self, x_batch: np.ndarray, @@ -429,8 +335,70 @@ def evaluate_batch( scores_batch: The evaluation results. """ + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Predict on input. + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + results = [] + for _ in range(self.n_perturb_samples): + sub_results = [] + for patch_size in self.perturb_patch_sizes: + pred_deltas = [] + a_sums = [] + x_perturbed = x_batch.copy() + x_perturbed_h, x_perturbed_w = x_perturbed.shape[-2:] + # Pad the input + padding_h, padding_w = utils.get_padding_size(x_perturbed_h, patch_size), utils.get_padding_size( + x_perturbed_w, patch_size + ) + x_perturbed_pad = utils._pad_array( + x_perturbed, + ((0, 0), (0, 0), padding_h, padding_w), + mode="edge", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + x_perturbed_pad_shape = x_perturbed_pad.shape + for x_indices in utils.get_block_indices(x_perturbed_pad, patch_size): + # Perturb input by block indices of certain patch size + x_perturbed_pad = self.perturb_func( + arr=x_perturbed_pad.reshape(batch_size, -1), + indices=x_indices, + ) + x_perturbed_pad = x_perturbed_pad.reshape(*x_perturbed_pad_shape) + x_perturbed = x_perturbed_pad[ + :, + :, + padding_h[0] : x_perturbed_pad.shape[2] - padding_h[1], + padding_w[0] : x_perturbed_pad.shape[3] - padding_w[1], + ] + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. 
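+                    # shape_input keeps the batch dimension; the prediction is then read out per sample for its target class in y_batch.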
+ x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + x_diff = x_batch - x_perturbed + a_diff = a_batch * x_diff.reshape(batch_size, -1) + + pred_deltas.append(y_pred - y_pred_perturb) + a_sums.append(np.sum(a_diff, axis=-1)) + + pred_deltas = np.stack(pred_deltas, axis=1) + a_sums = np.stack(a_sums, axis=1) + assert callable(self.loss_func) + sub_results.append(self.loss_func(a=pred_deltas, b=a_sums, batched=True)) + results.append(np.mean(np.stack(sub_results, axis=1), axis=-1)) + results = np.stack(results, axis=1) + return np.mean(results, axis=-1) diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 2b0650cdd..fbf972f30 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -10,7 +10,7 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import baseline_replacement_by_mask from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( DataType, @@ -126,14 +126,12 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = baseline_replacement_by_mask # Save metric-specific attributes. self.segmentation_method = segmentation_method self.nr_channels = None - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -259,79 +257,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - Returns - ------- - float - The evaluation results. - """ - # Predict on x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - # Segment image. - segments = utils.get_superpixel_segments( - img=np.moveaxis(x, 0, -1).astype("double"), - segmentation_method=self.segmentation_method, - ) - nr_segments = len(np.unique(segments)) - asserts.assert_nr_segments(nr_segments=nr_segments) - - # Calculate average attribution of each segment. - att_segs = np.zeros(nr_segments) - for i, s in enumerate(range(nr_segments)): - att_segs[i] = np.mean(a[:, segments == s]) - - # Sort segments based on the mean attribution (descending order). - s_indices = np.argsort(-att_segs) - - preds = [] - x_prev_perturbed = x - - for i_ix, s_ix in enumerate(s_indices): - # Perturb input by indices of attributions. - a_ix = np.nonzero((segments == s_ix).flatten())[0] - - x_perturbed = self.perturb_func( - arr=x_prev_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change( - x=x_prev_perturbed, x_perturbed=x_perturbed - ) - - # Predict on perturbed input x. 
- x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - # Normalise the scores to be within range [0, 1]. - preds.append(float(y_pred_perturb / y_pred)) - x_prev_perturbed = x_perturbed - - # Calculate the area over the curve (AOC) score. - aoc = len(preds) - utils.calculate_auc(np.array(preds)) - return aoc - def custom_preprocess( self, x_batch: np.ndarray, @@ -397,7 +322,68 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + + # Predict on input. + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + # Segment image. + segments_batch = [] + s_indices_batch = [] + for x, a in zip(x_batch, a_batch): + segments = utils.get_superpixel_segments( + img=np.moveaxis(x, 0, -1).astype("double"), + segmentation_method=self.segmentation_method, + ) + nr_segments = len(np.unique(segments)) + asserts.assert_nr_segments(nr_segments=nr_segments) + segments_batch.append(segments) + + # Calculate average attribution of each segment. + att_segs = np.zeros(nr_segments) + for i, s in enumerate(range(nr_segments)): + att_segs[i] = np.mean(a[:, segments == s]) + + # Sort segments based on the mean attribution (descending order). + s_indices = np.argsort(-att_segs) + s_indices_batch.append(s_indices) + segments_batch = np.stack(segments_batch, axis=0) + max_segments_len = max([len(s_indices) for s_indices in s_indices_batch]) + mask_preds_batch = np.array( + [[1.0] * len(s_indices) + [0] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch] + ) + s_indices_batch = np.array( + [s_indices.tolist() + [-1] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch] + ) + + preds = [] + x_perturbed = x_batch.copy() + for s_indices_segment in s_indices_batch.T: + # Perturb input by indices of attributions. + mask = (segments_batch == s_indices_segment[:, None, None])[:, None] + + x_new_perturbed = self.perturb_func( + arr=x_perturbed, + mask=mask, + ) + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_new_perturbed, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_new_perturbed, x_new_perturbed.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + # Normalise the scores to be within range [0, 1]. + preds.append(y_pred_perturb / y_pred) + x_perturbed = x_new_perturbed + preds = np.stack(preds, axis=1) * mask_preds_batch + # Calculate the area over the curve (AOC) score. 
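+        # mask_preds_batch.sum(-1) is the number of real (non-padded) segments per sample, so the AOC is taken over each sample's own curve length.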
+ aoc = mask_preds_batch.sum(-1) - utils.calculate_auc(preds, batched=True) + return aoc diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 7a7370e25..829c9b75f 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -12,7 +12,9 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, +) from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( DataType, @@ -140,15 +142,13 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes. self.patch_size = patch_size self.order = order.lower() self.regions_evaluation = regions_evaluation - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. asserts.assert_attributions_order(order=self.order) @@ -276,178 +276,137 @@ def __call__( **kwargs, ) - def evaluate_instance( + @property + def get_auc_score(self): + """Calculate the area under the curve (AUC) score for several test samples.""" + return np.mean([utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores]) + + def evaluate_batch( self, model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> List[float]: + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[List[float]]: """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- model: ModelInterface A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + kwargs: + Unused. Returns ------- - : list + scores_batch: The evaluation results. """ + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + x_batch_shape = x_batch.shape + if len(x_batch.shape) == 3: + x_batch = x_batch[:, :, None] + a_batch = a_batch[:, :, None] - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) + batch_size = a_batch.shape[0] + # Predict on input. + x_input = model.shape_input(x_batch, x_batch_shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] patches = [] - x_perturbed = x.copy() + x_perturbed = x_batch.copy() # Pad input and attributions. This is needed to allow for any patch_size. 
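+        # get_padding_size splits the padding as evenly as possible between both sides so that height and width become multiples of patch_size.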
- pad_width = self.patch_size - 1 - x_pad = utils._pad_array(x, pad_width, mode="constant", padded_axes=self.a_axes) - a_pad = utils._pad_array(a, pad_width, mode="constant", padded_axes=self.a_axes) + x_perturbed_h, x_perturbed_w = x_perturbed.shape[-2:] + padding_h, padding_w = utils.get_padding_size(x_perturbed_h, self.patch_size), utils.get_padding_size( + x_perturbed_w, self.patch_size + ) + padding = ((0, 0), (0, 0), padding_h, padding_w) + x_pad = utils._pad_array( + x_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + a_pad = utils._pad_array( + a_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) # Create patches across whole input shape and aggregate attributions. att_sums = [] - axis_iterators = [ - range(pad_width, x_pad.shape[axis] - pad_width) for axis in self.a_axes - ] - for top_left_coords in itertools.product(*axis_iterators): + patches = [] + for block_indices in utils.get_block_indices(x_pad, self.patch_size): # Create slice for patch. - patch_slice = utils.create_patch_slice( - patch_size=self.patch_size, - coords=top_left_coords, - ) + a_sum = a_pad.reshape(batch_size, -1)[np.arange(batch_size)[:, None], block_indices].sum(axis=-1) # Sum attributions for patch. - att_sums.append( - a_pad[utils.expand_indices(a_pad, patch_slice, self.a_axes)].sum() - ) - patches.append(patch_slice) + att_sums.append(a_sum) + patches.append(block_indices) + att_sums = np.stack(att_sums, -1) + patches = np.stack(patches, 1) if self.order == "random": # Order attributions randomly. - order = np.arange(len(patches)) - np.random.shuffle(order) + order = np.array([np.random.permutation(patches.shape[1]) for _ in range(batch_size)]) elif self.order == "morf": # Order attributions according to the most relevant first. - order = np.argsort(att_sums)[::-1] + order = np.argsort(-att_sums, -1) elif self.order == "lerf": # Order attributions according to the least relevant first. - order = np.argsort(att_sums) + order = np.argsort(att_sums, -1) else: - raise ValueError( - "Chosen order must be in ['random', 'morf', 'lerf'] but is: {self.order}." - ) + raise ValueError("Chosen order must be in ['random', 'morf', 'lerf'] but is: {self.order}.") # Create ordered list of patches. - ordered_patches = [patches[p] for p in order] - - # Remove overlapping patches - blocked_mask = np.zeros(x_pad.shape, dtype=bool) - ordered_patches_no_overlap = [] - for patch_slice in ordered_patches: - patch_mask = np.zeros(x_pad.shape, dtype=bool) - patch_mask[ - utils.expand_indices(patch_mask, patch_slice, self.a_axes) - ] = True - # patch_mask_exp = utils.expand_indices(patch_mask, patch_slice, self.a_axes) - # patch_mask[patch_mask_exp] = True - intersected = blocked_mask & patch_mask - - if not intersected.any(): - ordered_patches_no_overlap.append(patch_slice) - blocked_mask = blocked_mask | patch_mask - - if len(ordered_patches_no_overlap) >= self.regions_evaluation: - break - - # Warn - warn.warn_iterations_exceed_patch_number( - self.regions_evaluation, len(ordered_patches_no_overlap) - ) + ordered_patches = patches[np.arange(batch_size)[:, None], order].transpose(1, 0, 2) # Increasingly perturb the input and store the decrease in function value. - results = [None for _ in range(len(ordered_patches_no_overlap))] - for patch_id, patch_slice in enumerate(ordered_patches_no_overlap): - # Pad x_perturbed. The mode should probably depend on the used perturb_func? 
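+        # The edge-padded input is flattened to (batch_size, -1), one block is replaced per step, and the padding is stripped again before every prediction.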
- x_perturbed_pad = utils._pad_array( - x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes - ) - + results = [] + x_perturbed_pad = utils._pad_array( + x_perturbed, + padding, + mode="edge", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + x_perturbed_pad_shape = x_perturbed_pad.shape + for patch_slice in ordered_patches[: self.regions_evaluation]: # Perturb. - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - ) + x_perturbed_pad = self.perturb_func(arr=x_perturbed_pad.reshape(batch_size, -1), indices=patch_slice) # Remove padding. - x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - results[patch_id] = y_pred - y_pred_perturb - + x_perturbed_pad = x_perturbed_pad.reshape(*x_perturbed_pad_shape) + x_perturbed = x_perturbed_pad[ + :, + :, + padding_h[0] : x_perturbed_pad.shape[2] - padding_h[1], + padding_w[0] : x_perturbed_pad.shape[3] - padding_w[1], + ] + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch_shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + results.append(y_pred - y_pred_perturb) + results = np.stack(results, 1) return results - - @property - def get_auc_score(self): - """Calculate the area under the curve (AUC) score for several test samples.""" - return np.mean( - [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] - ) - - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ) -> List[List[float]]: - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. - a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - kwargs: - Unused. - - Returns - ------- - scores_batch: - The evaluation results. - """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 5e25f0dda..d9b348567 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -132,9 +132,7 @@ def __init__( self.percentages = percentages self.a_size = None - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, noise=noise - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, noise=noise) # Asserts and warnings. 
if not self.disable_warnings: @@ -256,57 +254,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> List[float]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - list: - The evaluation results. - """ - # Order indices. - ordered_indices = np.argsort(a, axis=None)[::-1] - - results_instance = np.array([None for _ in self.percentages]) - - for p_ix, p in enumerate(self.percentages): - top_k_indices = ordered_indices[: int(self.a_size * p / 100)] - - x_perturbed = self.perturb_func( # type: ignore - arr=x, - indices=top_k_indices, - ) - - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - class_pred_perturb = np.argmax(model.predict(x_input)) - - # Write a boolean into the percentage results. - results_instance[p_ix] = int(y == class_pred_perturb) - - # Return list of booleans for each percentage. - return results_instance - def custom_batch_preprocess(self, a_batch: np.ndarray, **kwargs) -> None: """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" if self.a_size is None: @@ -330,9 +277,10 @@ def custom_postprocess( """ # Calculate accuracy for every number of most important pixels removed. + percentage_scores = self.evaluation_scores[0].mean(axis=0) self.evaluation_scores = { - percentage: np.mean(np.array(self.evaluation_scores)[:, p_ix]) - for p_ix, percentage in enumerate(self.percentages) + percentage: percentage_score + for p_ix, (percentage, percentage_score) in enumerate(zip(self.percentages, percentage_scores)) } def evaluate_batch( @@ -365,7 +313,38 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). + ordered_indices = np.argsort(-a_batch, axis=1) + results = [] + for p in self.percentages: + top_k_indices = ordered_indices[:, : int(self.a_size * p / 100)] + + x_perturbed = [] + for x_element, top_k_index in zip(x_batch, top_k_indices): + x_perturbed_element = self.perturb_func( # type: ignore + arr=x_element, + indices=top_k_index, + ) + x_perturbed.append(x_perturbed_element) + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed) + x_perturbed = np.stack(x_perturbed, axis=0) + + # Predict on perturbed input x and store the difference from predicting on unperturbed input. + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + class_pred_perturb = np.argmax(model.predict(x_input), axis=-1) + + # Write a boolean into the percentage results. 
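+            # A True entry means the model still predicts the target class after the top-p% most attributed features have been perturbed.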
+ results.append(y_batch == class_pred_perturb) + results = np.stack(results, axis=1).astype(int) + # print(results_instance) + # Return list of booleans for each percentage. + return [results] diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 2058d05e4..35b7b8358 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -12,7 +12,7 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import baseline_replacement_by_indices, batch_baseline_replacement_by_indices from quantus.helpers import plotting, utils, warn from quantus.helpers.enums import ( DataType, @@ -134,22 +134,17 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes. self.patch_size = patch_size - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline' and the patch size for masking" - " 'patch_size'" - ), + sensitive_params=("baseline value 'perturb_baseline' and the patch size for masking" " 'patch_size'"), data_domain_applicability=( f"Also, the current implementation only works for 3-dimensional (image) data." ), @@ -266,117 +261,10 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> List[float]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - : list - The evaluation results. - """ - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - patches = [] - x_perturbed = x.copy() - - # Pad input and attributions. This is needed to allow for any patch_size. - pad_width = self.patch_size - 1 - x_pad = utils._pad_array(x, pad_width, mode="constant", padded_axes=self.a_axes) - a_pad = utils._pad_array(a, pad_width, mode="constant", padded_axes=self.a_axes) - - # Get patch indices of sorted attributions (descending). - att_sums = [] - axis_iterators = [ - range(pad_width, x_pad.shape[axis] - pad_width) for axis in self.a_axes - ] - for top_left_coords in itertools.product(*axis_iterators): - # Create slice for patch. - patch_slice = utils.create_patch_slice( - patch_size=self.patch_size, - coords=top_left_coords, - ) - - # Sum attributions for patch. - att_sums.append( - a_pad[utils.expand_indices(a_pad, patch_slice, self.a_axes)].sum() - ) - patches.append(patch_slice) - - # Create ordered list of patches. - ordered_patches = [patches[p] for p in np.argsort(att_sums)[::-1]] - - # Remove overlapping patches. 
- blocked_mask = np.zeros(x_pad.shape, dtype=bool) - ordered_patches_no_overlap = [] - for patch_slice in ordered_patches: - patch_mask = np.zeros(x_pad.shape, dtype=bool) - patch_mask_exp = utils.expand_indices(patch_mask, patch_slice, self.a_axes) - patch_mask[patch_mask_exp] = True - intersected = blocked_mask & patch_mask - - if not intersected.any(): - ordered_patches_no_overlap.append(patch_slice) - blocked_mask = blocked_mask | patch_mask - - # Increasingly perturb the input and store the decrease in function value. - results = np.array([None for _ in range(len(ordered_patches_no_overlap))]) - for patch_id, patch_slice in enumerate(ordered_patches_no_overlap): - # Pad x_perturbed. The mode should depend on the used perturb_func. - x_perturbed_pad = utils._pad_array( - x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes - ) - - # Perturb the input. - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - ) - - # Remove padding. - x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - results[patch_id] = y_pred_perturb - - return results - @property def get_auc_score(self): """Calculate the area under the curve (AUC) score for several test samples.""" - return np.mean( - [ - utils.calculate_auc(np.array(curve)) - for i, curve in enumerate(self.evaluation_scores) - ] - ) + return np.mean([utils.calculate_auc(np.array(curve)) for i, curve in enumerate(self.evaluation_scores)]) def evaluate_batch( self, @@ -408,7 +296,86 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + x_batch_shape = x_batch.shape + if len(x_batch.shape) == 3: + x_batch = x_batch[:, :, None] + a_batch = a_batch[:, :, None] + + batch_size = a_batch.shape[0] + + patches = [] + x_perturbed = x_batch.copy() + + # Pad input and attributions. This is needed to allow for any patch_size. + x_perturbed_h, x_perturbed_w = x_perturbed.shape[-2:] + padding_h, padding_w = utils.get_padding_size(x_perturbed_h, self.patch_size), utils.get_padding_size( + x_perturbed_w, self.patch_size + ) + padding = ((0, 0), (0, 0), padding_h, padding_w) + x_pad = utils._pad_array( + x_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + a_pad = utils._pad_array( + a_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + + # Get patch indices of sorted attributions (descending). + att_sums = [] + patches = [] + for block_indices in utils.get_block_indices(x_pad, self.patch_size): + # Create slice for patch. + a_sum = a_pad.reshape(batch_size, -1)[np.arange(batch_size)[:, None], block_indices].sum(axis=-1) + + # Sum attributions for patch. + att_sums.append(a_sum) + patches.append(block_indices) + att_sums = np.stack(att_sums, -1) + patches = np.stack(patches, 1) + + # Create ordered list of patches. 
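+        # Patches are ordered per sample by descending attribution sum; the transpose gives shape (n_patches, batch_size, features_per_patch) so the loop below removes one patch per step.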
+ order = np.argsort(-att_sums, -1) + ordered_patches = patches[np.arange(batch_size)[:, None], order].transpose(1, 0, 2) + + # Increasingly perturb the input and store the decrease in function value. + results = [] + x_perturbed_pad = utils._pad_array( + x_perturbed, + padding, + mode="edge", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + x_perturbed_pad_shape = x_perturbed_pad.shape + for patch_slice in ordered_patches: + # Perturb. + x_perturbed_pad = self.perturb_func(arr=x_perturbed_pad.reshape(batch_size, -1), indices=patch_slice) + + # Remove padding. + x_perturbed_pad = x_perturbed_pad.reshape(*x_perturbed_pad_shape) + x_perturbed = x_perturbed_pad[ + :, + :, + padding_h[0] : x_perturbed_pad.shape[2] - padding_h[1], + padding_w[0] : x_perturbed_pad.shape[3] - padding_w[1], + ] + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch_shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + results.append(y_pred_perturb) + results = np.stack(results, 1, dtype=np.float64) + + return results diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 6abafee83..363645d0c 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -10,9 +10,12 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np +import math from quantus.functions.normalise_func import normalise_by_max -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, +) from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, plotting, warn from quantus.helpers.enums import ( @@ -141,7 +144,7 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes. if similarity_func is None: @@ -149,9 +152,7 @@ def __init__( self.similarity_func = similarity_func self.n_max_percentage = n_max_percentage self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -275,69 +276,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> Dict[str, List[float]]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - (Dict[str, List[float]]): The evaluation results. - """ - - # Reshape the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on x. 
- x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - att_sums = [] - pred_deltas = [] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Sum attributions. - att_sums.append(float(a[a_ix].sum())) - - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(y_pred - y_pred_perturb) - - # Each list-element of self.evaluation_scores will be such a dictionary - # We will unpack that later in custom_postprocess(). - return {"att_sums": att_sums, "pred_deltas": pred_deltas} - def custom_preprocess( self, x_batch: np.ndarray, @@ -382,35 +320,18 @@ def custom_postprocess( ------- None """ - max_features = int( - self.n_max_percentage * np.prod(x_batch.shape[2:]) // self.features_in_step - ) + max_features = int(self.n_max_percentage * np.prod(x_batch.shape[2:]) // self.features_in_step) # Get pred_deltas and att_sums from result list. - sub_results_pred_deltas: List[Any] = [ - r["pred_deltas"] for r in self.evaluation_scores - ] - sub_results_att_sums: List[Any] = [ - r["att_sums"] for r in self.evaluation_scores - ] - - # Re-arrange sub-lists so that they are sorted by n. - sub_results_pred_deltas_l: Dict[int, Any] = {k: [] for k in range(max_features)} - sub_results_att_sums_l: Dict[int, Any] = {k: [] for k in range(max_features)} - - for k in range(max_features): - for pred_deltas_instance in sub_results_pred_deltas: - sub_results_pred_deltas_l[k].append(pred_deltas_instance[k]) - for att_sums_instance in sub_results_att_sums: - sub_results_att_sums_l[k].append(att_sums_instance[k]) + pred_deltas: np.array = self.evaluation_scores[0]["pred_deltas"] + att_sums: np.array = self.evaluation_scores[0]["att_sums"] # Compute the similarity for each n. - self.evaluation_scores = [ - self.similarity_func( - a=sub_results_att_sums_l[k], b=sub_results_pred_deltas_l[k] - ) - for k in range(max_features) - ] + self.evaluation_scores = self.similarity_func( + a=pred_deltas[:, :max_features].T, + b=att_sums[:, :max_features].T, + batched=True, + ) def evaluate_batch( self, @@ -442,7 +363,53 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a_batch, axis=1) + + # Predict on x. + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + n_perturbations = math.ceil(n_features / self.features_in_step) + pred_deltas = [] + att_sums = [] + x_batch_shape = x_batch.shape + x_perturbed = x_batch.copy() + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. 
+ a_ix = a_indices[ + :, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, + ] + x_perturbed = self.perturb_func( + arr=x_perturbed.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Sum attributions. + att_sums.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=-1)) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + pred_deltas.append(y_pred - y_pred_perturb) + pred_deltas = np.stack(pred_deltas, axis=1) + att_sums = np.stack(att_sums, axis=1) + + # Each list-element of self.evaluation_scores will be such a dictionary + # We will unpack that later in custom_postprocess(). + return [{"att_sums": att_sums, "pred_deltas": pred_deltas}] diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 4d2fb9ca9..4de52c3b9 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -249,40 +249,6 @@ def __call__( **kwargs, ) - @staticmethod - def evaluate_instance( - i: int, - a_sim_vector: np.ndarray, - y_pred_classes: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - i: int - The index of the current instance. - a_sim_vector: any - The custom input to be evaluated on an instance-basis. - y_pred_classes: np,ndarray - The class predictions of the complete input dataset. - - Returns - ------- - float - The evaluation results. - """ - - # Metric logic. - pred_a = y_pred_classes[i] - low_dist_a = np.argwhere(a_sim_vector == 1.0).flatten() - low_dist_a = low_dist_a[low_dist_a != i] - pred_low_dist_a = y_pred_classes[low_dist_a] - - if len(low_dist_a) == 0: - return 0 - return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) - def custom_batch_preprocess( self, model: ModelInterface, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> Dict[str, np.ndarray]: @@ -295,9 +261,7 @@ def custom_batch_preprocess( a_sim_matrix[dist_matrix <= self.threshold] = 1 # Predict on input. - x_input = model.shape_input( - x_batch, x_batch[0].shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() return { @@ -334,10 +298,11 @@ def evaluate_batch( evaluation_scores: List of measured sufficiency for each entry in the batch. """ - - return [ - self.evaluate_instance( - i=i, a_sim_vector=a_sim_vector, y_pred_classes=y_pred_classes - ) - for i, a_sim_vector in zip(i_batch, a_sim_vector_batch) - ] + batch_size = y_pred_classes.shape[0] + # Metric logic. 
+ a_sim_vector_batch *= 1 - np.eye(batch_size, dtype=np.int32) + pred_classes_equality = y_pred_classes[None] == y_pred_classes[:, None] + low_dist_freqs = a_sim_vector_batch.sum(-1) + return (pred_classes_equality * a_sim_vector_batch).sum(-1) / np.where( + low_dist_freqs == 0, np.inf, low_dist_freqs + ) diff --git a/tests/conftest.py b/tests/conftest.py index 6e87b554b..0b3b3efd8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,13 +6,17 @@ import pytest import torch from keras.datasets import cifar10 -from quantus.helpers.model.models import (CifarCNNModel, ConvNet1D, - ConvNet1DTF, LeNet, LeNetTF, - TitanicSimpleTFModel, - TitanicSimpleTorchModel) +from quantus.helpers.model.models import ( + CifarCNNModel, + ConvNet1D, + ConvNet1DTF, + LeNet, + LeNetTF, + TitanicSimpleTFModel, + TitanicSimpleTorchModel, +) from sklearn.model_selection import train_test_split -from transformers import (AutoModelForSequenceClassification, AutoTokenizer, - set_seed) +from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed CIFAR_IMAGE_SIZE = 32 MNIST_IMAGE_SIZE = 28 @@ -20,7 +24,8 @@ MINI_BATCH_SIZE = 8 RANDOM_SEED = 42 -@pytest.fixture(scope='function', autouse=True) + +@pytest.fixture(scope="function", autouse=True) def reset_prngs(): set_seed(42) @@ -29,9 +34,7 @@ def reset_prngs(): def load_mnist_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" model = LeNet() - model.load_state_dict( - torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) - ) + model.load_state_dict(torch.load("tests/assets/mnist", map_location="cpu", weights_only=True)) return model @@ -58,7 +61,7 @@ def load_1d_1ch_conv_model(): model.eval() # TODO: add trained model weights # model.load_state_dict( - # torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) + # torch.load("tests/assets/mnist", map_location="cpu", weights_only=True) # ) return model @@ -70,7 +73,7 @@ def load_1d_3ch_conv_model(): model.eval() # TODO: add trained model weights # model.load_state_dict( - # torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) + # torch.load("tests/assets/mnist", map_location="cpu", pweights_only=True) # ) return model @@ -89,9 +92,7 @@ def load_1d_3ch_conv_model_tf(): def load_mnist_images(): """Load a batch of MNIST digits: inputs and outputs to use for testing.""" x_batch = ( - np.loadtxt("tests/assets/mnist_x") - .astype(float) - .reshape((BATCH_SIZE, 1, MNIST_IMAGE_SIZE, MNIST_IMAGE_SIZE)) + np.loadtxt("tests/assets/mnist_x").astype(float).reshape((BATCH_SIZE, 1, MNIST_IMAGE_SIZE, MNIST_IMAGE_SIZE)) )[:MINI_BATCH_SIZE] y_batch = np.loadtxt("tests/assets/mnist_y").astype(int)[:MINI_BATCH_SIZE] return {"x_batch": x_batch, "y_batch": y_batch} @@ -101,11 +102,9 @@ def load_mnist_images(): def load_cifar10_images(): """Load a batch of MNIST digits: inputs and outputs to use for testing.""" (x_train, y_train), (_, _) = cifar10.load_data() - x_batch = ( - x_train[:BATCH_SIZE] - .reshape((BATCH_SIZE, 3, CIFAR_IMAGE_SIZE, CIFAR_IMAGE_SIZE)) - .astype(float) - )[:MINI_BATCH_SIZE] + x_batch = (x_train[:BATCH_SIZE].reshape((BATCH_SIZE, 3, CIFAR_IMAGE_SIZE, CIFAR_IMAGE_SIZE)).astype(float))[ + :MINI_BATCH_SIZE + ] y_batch = y_train[:BATCH_SIZE].reshape(-1).astype(int)[:MINI_BATCH_SIZE] return {"x_batch": x_batch, "y_batch": y_batch} @@ -185,7 +184,7 @@ def flat_sequence_array(): @pytest.fixture(scope="session", autouse=True) def titanic_model_torch(): model = TitanicSimpleTorchModel() - 
model.load_state_dict(torch.load("tests/assets/titanic_model_torch.pickle")) + model.load_state_dict(torch.load("tests/assets/titanic_model_torch.pickle", weights_only=True)) return model @@ -244,9 +243,7 @@ def load_hf_distilbert_sequence_classifier(): TODO """ DISTILBERT_BASE = "distilbert-base-uncased" - model = AutoModelForSequenceClassification.from_pretrained( - DISTILBERT_BASE, cache_dir="/tmp/" - ) + model = AutoModelForSequenceClassification.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/") return model @@ -264,4 +261,4 @@ def dummy_hf_tokenizer(): @pytest.fixture(scope="session", autouse=True) def set_env(): """Set ENV var, so test outputs are not polluted by progress bars and warnings.""" - os.environ["PYTEST"] = "1" \ No newline at end of file + os.environ["PYTEST"] = "1" diff --git a/tests/metrics/test_faithfulness_metrics.py b/tests/metrics/test_faithfulness_metrics.py index d8b8b690c..3a72d1df6 100644 --- a/tests/metrics/test_faithfulness_metrics.py +++ b/tests/metrics/test_faithfulness_metrics.py @@ -292,9 +292,7 @@ def test_faithfulness_correlation( **call_params, )[0] - assert np.all( - ((scores >= expected["min"]) & (scores <= expected["max"])) - ), "Test failed." + assert np.all(((scores >= expected["min"]) & (scores <= expected["max"]))), "Test failed." @pytest.mark.faithfulness @@ -462,9 +460,7 @@ def test_faithfulness_estimate( **call_params, ) - assert all( - ((s >= expected["min"]) & (s <= expected["max"])) for s in scores - ), "Test failed." + assert all(((s >= expected["min"]) & (s <= expected["max"])) for s in scores), "Test failed." @pytest.mark.faithfulness @@ -598,9 +594,7 @@ def test_iterative_removal_of_features( **call_params, ) - assert all( - ((s >= expected["min"]) & (s <= expected["max"])) for s in scores - ), "Test failed." + assert all(((s >= expected["min"]) & (s <= expected["max"])) for s in scores), "Test failed." @pytest.mark.faithfulness @@ -1066,13 +1060,7 @@ def test_pixel_flipping( **call_params, ) - assert all( - [ - (s >= expected["min"] and s <= expected["max"]) - for s_list in scores - for s in s_list - ] - ), "Test failed." + assert all([(s >= expected["min"] and s <= expected["max"]) for s_list in scores for s in s_list]), "Test failed." @pytest.mark.faithfulness @@ -1134,7 +1122,7 @@ def test_pixel_flipping( "normalise": True, "order": "morf", "disable_warnings": True, - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, }, "call": { "explain_func": explain, @@ -1240,13 +1228,7 @@ def test_region_perturbation( **call_params, ) - assert all( - [ - (s >= expected["min"] and s <= expected["max"]) - for s_list in scores - for s in s_list - ] - ), "Test failed." + assert all([(s >= expected["min"] and s <= expected["max"]) for s_list in scores for s in s_list]), "Test failed." @pytest.mark.faithfulness @@ -1615,9 +1597,7 @@ def test_sensitivity_n( **call_params, ) - assert all( - ((s >= expected["min"]) & (s <= expected["max"])) for s in scores - ), "Test failed." + assert all(((s >= expected["min"]) & (s <= expected["max"])) for s in scores), "Test failed." 
@pytest.mark.faithfulness @@ -1629,7 +1609,7 @@ def test_sensitivity_n( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "return_aggregate": False, "normalise": True, "abs": True, @@ -1652,7 +1632,7 @@ def test_sensitivity_n( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "return_aggregate": False, "normalise": True, "abs": True, diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index ac86b2a7b..7a43aabe4 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -324,9 +324,7 @@ def load_artificial_attribution(): def load_mnist_adaptive_lenet_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" model = LeNetAdaptivePooling(input_shape=(1, 28, 28)) - model.load_state_dict( - torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) - ) + model.load_state_dict(torch.load("tests/assets/mnist", map_location="cpu", weights_only=True)) return model @@ -349,9 +347,7 @@ def load_mnist_mosaics(load_mnist_images): def load_cifar10_adaptive_lenet_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" model = LeNetAdaptivePooling(input_shape=(3, 32, 32)) - model.load_state_dict( - torch.load("tests/assets/cifar10", map_location="cpu", pickle_module=pickle) - ) + model.load_state_dict(torch.load("tests/assets/cifar10", map_location="cpu", weights_only=True)) return model @@ -917,9 +913,7 @@ def test_relevance_mass_accuracy( print(scores) if isinstance(expected, float): - assert ( - all(round(s, 2) == round(expected, 2) for s in scores) == True - ), "Test failed." + assert all(round(s, 2) == round(expected, 2) for s in scores) == True, "Test failed." elif "type" in expected: assert isinstance(scores, expected["type"]), "Test failed." 
else: From 67c8c37f7ee991c5536337f906c4dafcbd2dec15 Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Sat, 21 Sep 2024 09:40:17 +0200 Subject: [PATCH 07/11] small corrections for tests --- quantus/metrics/robustness/consistency.py | 4 +--- tests/conftest.py | 2 +- tests/functions/test_pytorch_model.py | 22 ++++++---------------- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index ea782d2db..d28f9b183 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -277,9 +277,7 @@ def custom_batch_preprocess( self, model: ModelInterface, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> Dict[str, np.ndarray]: """Compute additional arguments required for Consistency on batch-level.""" - x_input = model.shape_input( - x_batch, x_batch[0].shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) a_labels = np.array(list(map(self.discretise_func, a_batch_flat))) diff --git a/tests/conftest.py b/tests/conftest.py index 0b3b3efd8..bde949ac9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -254,7 +254,7 @@ def dummy_hf_tokenizer(): """ DISTILBERT_BASE = "distilbert-base-uncased" REFERENCE_TEXT = "The quick brown fox jumps over the lazy dog" - tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/") + tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/", clean_up_tokenization_spaces=True) return tokenizer(REFERENCE_TEXT, return_tensors="pt") diff --git a/tests/functions/test_pytorch_model.py b/tests/functions/test_pytorch_model.py index cbd240458..caa64f501 100644 --- a/tests/functions/test_pytorch_model.py +++ b/tests/functions/test_pytorch_model.py @@ -130,9 +130,7 @@ def test_get_softmax_arg_model( ), ], ) -def test_predict( - data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model -): +def test_predict(data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model): load_mnist_model.eval() training = params.pop("training", False) model = PyTorchModel(load_mnist_model, **params) @@ -171,9 +169,7 @@ def test_predict( ), ], ) -def test_shape_input( - data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model -): +def test_shape_input(data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model): load_mnist_model.eval() model = PyTorchModel(load_mnist_model, channel_first=params["channel_first"]) if not params["channel_first"]: @@ -266,9 +262,7 @@ def test_add_mean_shift_to_first_layer(load_mnist_model): ( lazy_fixture("load_hf_distilbert_sequence_classifier"), { - "input_ids": torch.tensor( - [[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]] - ), + "input_ids": torch.tensor([[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]), "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), }, False, @@ -291,15 +285,11 @@ def test_add_mean_shift_to_first_layer(load_mnist_model): ), ], ) -def test_huggingface_classifier_predict( - hf_model, data, softmax, model_kwargs, expected -): - model = PyTorchModel( - model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs - ) +def test_huggingface_classifier_predict(hf_model, data, softmax, model_kwargs, expected): + model = PyTorchModel(model=hf_model, softmax=softmax, 
model_predict_kwargs=model_kwargs) with expected: out = model.predict(x=data) - assert np.allclose(out, expected.enter_result), "Test failed." + assert np.allclose(out, expected.enter_result, atol=1e-5), "Test failed." @pytest.fixture From 99ba309e77aee398cd1ba8a549e85c3d067ca90d Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Wed, 2 Oct 2024 10:07:25 +0200 Subject: [PATCH 08/11] bugfix for batch size 1 when calling pearsonr --- quantus/functions/similarity_func.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index da7f4bc18..14f768a6c 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -39,8 +39,13 @@ def correlation_spearman(a: np.array, b: np.array, batched: bool = False, **kwar assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D" # Spearman correlation is not calculated row-wise like pearson. Instead it is calculated between each # pair from BOTH a and b - correlation = scipy.stats.spearmanr(a, b, axis=1)[0][: len(a), len(a) :] - return np.diag(correlation) + correlation = scipy.stats.spearmanr(a, b, axis=1)[0] + # if a and b batch size is 1, scipy returns a float instead of an array + if correlation.shape: + correlation = correlation[: len(a), len(a) :] + return np.diag(correlation) + else: + return np.array([correlation]) return scipy.stats.spearmanr(a, b)[0] From 811b99fa3c6318aa815673ab69e1748324aa65f4 Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Wed, 2 Oct 2024 10:13:24 +0200 Subject: [PATCH 09/11] removing a comment --- quantus/metrics/faithfulness/road.py | 1 - 1 file changed, 1 deletion(-) diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index d9b348567..bafaf1e22 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -345,6 +345,5 @@ def evaluate_batch( # Write a boolean into the percentage results. results.append(y_batch == class_pred_perturb) results = np.stack(results, axis=1).astype(int) - # print(results_instance) # Return list of booleans for each percentage. return [results] From 6591aab2e1f3ffc1a89e383d23cc9d4c52503243 Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Tue, 8 Oct 2024 12:05:35 +0200 Subject: [PATCH 10/11] resolving mypy tests --- quantus/functions/perturb_func.py | 4 +-- quantus/functions/similarity_func.py | 20 +++++++------- quantus/helpers/perturbation_utils.py | 7 ++--- quantus/helpers/utils.py | 4 +-- .../faithfulness/faithfulness_correlation.py | 27 +++++-------------- .../faithfulness/faithfulness_estimate.py | 27 +++++-------------- quantus/metrics/faithfulness/irof.py | 10 +++---- .../faithfulness/monotonicity_correlation.py | 27 ++++++------------- .../metrics/faithfulness/pixel_flipping.py | 25 +++++------------ .../faithfulness/region_perturbation.py | 14 +++++----- quantus/metrics/faithfulness/selectivity.py | 13 +++++---- 11 files changed, 63 insertions(+), 115 deletions(-) diff --git a/quantus/functions/perturb_func.py b/quantus/functions/perturb_func.py index a586501c9..85533ae36 100644 --- a/quantus/functions/perturb_func.py +++ b/quantus/functions/perturb_func.py @@ -357,7 +357,7 @@ def gaussian_noise( def batch_gaussian_noise( arr: np.array, - indices: Tuple[slice, ...], # Alt. 
Union[int, Sequence[int], Tuple[np.array]], + indices: np.array, perturb_mean: float = 0.0, perturb_std: float = 0.01, **kwargs, @@ -459,7 +459,7 @@ def uniform_noise( def batch_uniform_noise( arr: np.array, - indices: Tuple[slice, ...], # Alt. Union[int, Sequence[int], Tuple[np.array]], + indices: np.array, lower_bound: float = 0.02, upper_bound: Union[None, float] = None, **kwargs, diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 14f768a6c..e408818c5 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -7,7 +7,7 @@ # Quantus project URL: . # Quantus project URL: https://github.com/understandable-machine-intelligence-lab/Quantus -from typing import Union +from typing import Union, List import numpy as np import scipy @@ -105,7 +105,7 @@ def correlation_kendall_tau(a: np.array, b: np.array, batched: bool = False, **k return scipy.stats.kendalltau(a, b)[0] -def distance_euclidean(a: np.array, b: np.array, **kwargs) -> float: +def distance_euclidean(a: np.array, b: np.array, **kwargs) -> Union[float, np.array]: """ Calculate Euclidean distance of two images (or explanations). @@ -120,13 +120,13 @@ def distance_euclidean(a: np.array, b: np.array, **kwargs) -> float: Returns ------- - float - The similarity score. + Union[float, np.array] + The similarity score or a batch of similarity scores. """ return ((a - b) ** 2).sum(axis=-1) ** 0.5 -def distance_manhattan(a: np.array, b: np.array, **kwargs) -> float: +def distance_manhattan(a: np.array, b: np.array, **kwargs) -> Union[float, np.array]: """ Calculate Manhattan distance of two images (or explanations). @@ -141,8 +141,8 @@ def distance_manhattan(a: np.array, b: np.array, **kwargs) -> float: Returns ------- - float - The similarity score. + Union[float, np.array] + The similarity score or a batch of similarity scores. """ return abs(a - b).sum(-1) @@ -272,7 +272,7 @@ def cosine(a: np.array, b: np.array, **kwargs) -> float: return scipy.spatial.distance.cosine(u=a, v=b) -def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float: +def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, List[float]]: """ Calculate Structural Similarity Index Measure of two images (or explanations). @@ -289,8 +289,8 @@ def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float: Returns ------- - float - The similarity score. + Union[float, List[float]] + The similarity score, returns a list if batched. """ def inner(aa: np.array, bb: np.array) -> float: diff --git a/quantus/helpers/perturbation_utils.py b/quantus/helpers/perturbation_utils.py index 71b3215a1..05845872e 100644 --- a/quantus/helpers/perturbation_utils.py +++ b/quantus/helpers/perturbation_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import List, TYPE_CHECKING, Callable, Mapping +from typing import List, TYPE_CHECKING, Callable, Mapping, Optional import numpy as np import functools @@ -18,11 +18,8 @@ class PerturbFunc(Protocol): def __call__( self, arr: np.ndarray, - indices: np.ndarray, - indexed_axes: np.ndarray, **kwargs, - ) -> np.ndarray: - ... + ) -> np.ndarray: ... 
def make_perturb_func( diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 8880f239a..3f072e30c 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -714,7 +714,7 @@ def _unpad_array( return unpadded_arr -def get_padding_size(dim: int, patch_size: int) -> Tuple[int]: +def get_padding_size(dim: int, patch_size: int) -> Tuple[int, int]: """ Calculate the padding size (optionally) needed for a patch_size. @@ -727,7 +727,7 @@ def get_padding_size(dim: int, patch_size: int) -> Tuple[int]: Returns ------- - Tuple[int] + Tuple[int, int] A tuple of values passed to the utils._pad_array method for a particular dimension. """ modulo = dim % patch_size diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 805c342ce..fe38df61f 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -147,9 +147,7 @@ def __init__( self.similarity_func = similarity_func self.nr_runs = nr_runs self.subset_size = subset_size - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -294,9 +292,7 @@ def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: returning a custom preprocess batch (custom_preprocess_batch). """ # Asserts. - asserts.assert_value_smaller_than_input_size( - x=x_batch, value=self.subset_size, value_name="subset_size" - ) + asserts.assert_value_smaller_than_input_size(x=x_batch, value=self.subset_size, value_name="subset_size") def evaluate_batch( self, @@ -338,9 +334,7 @@ def evaluate_batch( n_features = a_batch.shape[-1] # Predict on input. - x_input = model.shape_input( - x_batch, x_batch.shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] pred_deltas = [] @@ -351,10 +345,7 @@ def evaluate_batch( for i_ix in range(self.nr_runs): # Randomly mask by subset size. a_ix = np.stack( - [ - np.random.choice(n_features, self.subset_size, replace=False) - for _ in range(batch_size) - ], + [np.random.choice(n_features, self.subset_size, replace=False) for _ in range(batch_size)], axis=0, ) x_perturbed = self.perturb_func( @@ -365,14 +356,10 @@ def evaluate_batch( # Check if the perturbation caused change for x_element, x_perturbed_element in zip(x_batch, x_perturbed): - warn.warn_perturbation_caused_no_change( - x=x_element, x_perturbed=x_perturbed_element - ) + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) # Predict on perturbed input x. 
- x_input = model.shape_input( - x_perturbed, x_batch.shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] pred_deltas.append(y_pred - y_pred_perturb) @@ -381,6 +368,6 @@ def evaluate_batch( pred_deltas = np.stack(pred_deltas, axis=1) att_sums = np.stack(att_sums, axis=1) - similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) + similarity: np.array = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) return similarity.tolist() diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index 36d73622d..4cdf7aa20 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -132,18 +132,13 @@ def __init__( perturb_func = batch_baseline_replacement_by_indices self.similarity_func = similarity_func self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline' and similarity function " - "'similarity_func'" - ), + sensitive_params=("baseline value 'perturb_baseline' and similarity function " "'similarity_func'"), citation=( "Alvarez-Melis, David, and Tommi S. Jaakkola. 'Towards robust interpretability" " with self-explaining neural networks.' arXiv preprint arXiv:1806.07538 (2018)" @@ -326,9 +321,7 @@ def evaluate_batch( a_indices = np.argsort(-a_batch, axis=1) # Predict on input. - x_input = model.shape_input( - x_batch, x_batch.shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] n_perturbations = math.ceil(n_features / self.features_in_step) @@ -339,9 +332,7 @@ def evaluate_batch( # Perturb input by indices of attributions. a_ix = a_indices[ :, - perturbation_step_index - * self.features_in_step : (perturbation_step_index + 1) - * self.features_in_step, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, ] x_perturbed = self.perturb_func( arr=x_batch.reshape(batch_size, -1), @@ -351,14 +342,10 @@ def evaluate_batch( # Check if the perturbation caused change for x_element, x_perturbed_element in zip(x_batch, x_perturbed): - warn.warn_perturbation_caused_no_change( - x=x_element, x_perturbed=x_perturbed_element - ) + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) # Predict on perturbed input x. 
- x_input = model.shape_input( - x_perturbed, x_batch.shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] pred_deltas.append(y_pred - y_pred_perturb) @@ -367,6 +354,6 @@ def evaluate_batch( pred_deltas = np.stack(pred_deltas, axis=1) att_sums = np.stack(att_sums, axis=1) - similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) + similarity: np.array = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) return similarity.tolist() diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index fbf972f30..ce6aa63af 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -335,7 +335,7 @@ def evaluate_batch( # Segment image. segments_batch = [] - s_indices_batch = [] + s_indices_batch_list = [] for x, a in zip(x_batch, a_batch): segments = utils.get_superpixel_segments( img=np.moveaxis(x, 0, -1).astype("double"), @@ -352,14 +352,14 @@ def evaluate_batch( # Sort segments based on the mean attribution (descending order). s_indices = np.argsort(-att_segs) - s_indices_batch.append(s_indices) + s_indices_batch_list.append(s_indices) segments_batch = np.stack(segments_batch, axis=0) - max_segments_len = max([len(s_indices) for s_indices in s_indices_batch]) + max_segments_len = max([len(s_indices) for s_indices in s_indices_batch_list]) mask_preds_batch = np.array( - [[1.0] * len(s_indices) + [0] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch] + [[1.0] * len(s_indices) + [0] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch_list] ) s_indices_batch = np.array( - [s_indices.tolist() + [-1] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch] + [s_indices.tolist() + [-1] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch_list] ) preds = [] diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index d14e8f2ad..ee011e356 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -149,9 +149,7 @@ def __init__( self.eps = eps self.nr_samples = nr_samples self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -334,9 +332,7 @@ def evaluate_batch( """ # Predict on input x. - x_input = model.shape_input( - x_batch, x_batch.shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) batch_size = x_batch.shape[0] y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] @@ -363,9 +359,7 @@ def evaluate_batch( # Perturb input by indices of attributions. 
a_ix = a_indices[ :, - perturbation_step_index - * self.features_in_step : (perturbation_step_index + 1) - * self.features_in_step, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, ] y_pred_perturbs = [] @@ -378,23 +372,18 @@ def evaluate_batch( # Check if the perturbation caused change for x_element, x_perturbed_element in zip(x_batch, x_perturbed): - warn.warn_perturbation_caused_no_change( - x=x_element, x_perturbed=x_perturbed_element - ) + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) # Predict on perturbed input x. - x_input = model.shape_input( - x_perturbed, x_batch.shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] y_pred_perturbs.append(y_pred_perturb) y_pred_perturbs = np.stack(y_pred_perturbs, axis=1) - vars.append( - np.mean((y_pred_perturbs - y_pred[:, None]) ** 2, axis=1) * inv_pred - ) + vars.append(np.mean((y_pred_perturbs - y_pred[:, None]) ** 2, axis=1) * inv_pred) atts.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=1)) vars = np.stack(vars, axis=1) atts = np.stack(atts, axis=1) - return self.similarity_func(a=atts, b=vars, batched=True).tolist() + similarities: np.array = self.similarity_func(a=atts, b=vars, batched=True) + return similarities.tolist() diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index 0179cb415..20330f11d 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -134,9 +134,7 @@ def __init__( # Save metric-specific attributes. self.features_in_step = features_in_step self.return_auc_per_sample = return_auc_per_sample - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -284,9 +282,7 @@ def custom_preprocess( @property def get_auc_score(self): """Calculate the area under the curve (AUC) score for several test samples.""" - return np.mean( - [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] - ) + return np.mean([utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores]) def evaluate_batch( self, @@ -339,9 +335,7 @@ def evaluate_batch( # Perturb input by indices of attributions. a_ix = a_indices[ :, - perturbation_step_index - * self.features_in_step : (perturbation_step_index + 1) - * self.features_in_step, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, ] x_perturbed = self.perturb_func( arr=x_perturbed.reshape(batch_size, -1), @@ -351,19 +345,14 @@ def evaluate_batch( # Check if the perturbation caused change for x_element, x_perturbed_element in zip(x_batch, x_perturbed): - warn.warn_perturbation_caused_no_change( - x=x_element, x_perturbed=x_perturbed_element - ) + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) # Predict on perturbed input x. 
- x_input = model.shape_input( - x_perturbed, x_batch.shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] preds.append(y_pred_perturb) - preds = np.stack(preds, axis=1) if self.return_auc_per_sample: - return utils.calculate_auc(preds, batched=True).tolist() + return utils.calculate_auc(np.stack(preds, axis=1), batched=True).tolist() - return preds.tolist() + return np.stack(preds, axis=1).tolist() diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 829c9b75f..faec84a8c 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -324,7 +324,7 @@ def evaluate_batch( # Predict on input. x_input = model.shape_input(x_batch, x_batch_shape, channel_first=True, batched=True) y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] - patches = [] + x_perturbed = x_batch.copy() # Pad input and attributions. This is needed to allow for any patch_size. @@ -347,17 +347,17 @@ def evaluate_batch( ) # Create patches across whole input shape and aggregate attributions. - att_sums = [] - patches = [] + att_sums_list = [] + patches_list = [] for block_indices in utils.get_block_indices(x_pad, self.patch_size): # Create slice for patch. a_sum = a_pad.reshape(batch_size, -1)[np.arange(batch_size)[:, None], block_indices].sum(axis=-1) # Sum attributions for patch. - att_sums.append(a_sum) - patches.append(block_indices) - att_sums = np.stack(att_sums, -1) - patches = np.stack(patches, 1) + att_sums_list.append(a_sum) + patches_list.append(block_indices) + att_sums = np.stack(att_sums_list, -1) + patches = np.stack(patches_list, 1) if self.order == "random": # Order attributions randomly. diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 35b7b8358..d9197cbd7 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -306,7 +306,6 @@ def evaluate_batch( batch_size = a_batch.shape[0] - patches = [] x_perturbed = x_batch.copy() # Pad input and attributions. This is needed to allow for any patch_size. @@ -329,17 +328,17 @@ def evaluate_batch( ) # Get patch indices of sorted attributions (descending). - att_sums = [] - patches = [] + att_sums_list = [] + patches_list = [] for block_indices in utils.get_block_indices(x_pad, self.patch_size): # Create slice for patch. a_sum = a_pad.reshape(batch_size, -1)[np.arange(batch_size)[:, None], block_indices].sum(axis=-1) # Sum attributions for patch. - att_sums.append(a_sum) - patches.append(block_indices) - att_sums = np.stack(att_sums, -1) - patches = np.stack(patches, 1) + att_sums_list.append(a_sum) + patches_list.append(block_indices) + att_sums = np.stack(att_sums_list, -1) + patches = np.stack(patches_list, 1) # Create ordered list of patches. 
order = np.argsort(-att_sums, -1) From 1609535ebbf490e5946ad8188045d7a9a42f30fa Mon Sep 17 00:00:00 2001 From: Davor Vukadin Date: Tue, 8 Oct 2024 16:31:16 +0200 Subject: [PATCH 11/11] resolving the single missing mypy test --- quantus/metrics/randomisation/random_logit.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index cc42db9bc..728a5a680 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -267,16 +267,11 @@ def evaluate_instance( """ # Randomly select off-class labels. np.random.seed(self.seed) - y_off = np.array( - [ - np.random.choice( - [y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y] - ) - ] - ) + y_off = np.array([np.random.choice([y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y])]) # Explain against a random class. a_perturbed = self.explain_batch(model, np.expand_dims(x, axis=0), y_off) - return self.similarity_func(a.flatten(), a_perturbed.flatten()) + similarity = float(self.similarity_func(a.flatten(), a_perturbed.flatten())) + return similarity def custom_preprocess( self, @@ -328,7 +323,4 @@ def evaluate_batch( scores_batch: Evaluation results. """ - return [ - self.evaluate_instance(model, x, y, a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + return [self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch)]
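
A minimal, self-contained sketch of how the batched building blocks introduced in this series are meant to compose. The toy shapes, the zero baseline, and the random data below are illustrative assumptions only; the imports follow the module paths used in the patches above, and the batched Pearson call assumes a SciPy release whose pearsonr accepts an axis argument, as the batched similarity function itself does.

    import numpy as np

    from quantus.functions.perturb_func import batch_baseline_replacement_by_indices
    from quantus.functions.similarity_func import correlation_pearson

    rng = np.random.default_rng(0)

    # Toy batch: 4 samples with 8 features each (shapes chosen for illustration only).
    x_batch = rng.random((4, 8))
    a_batch = rng.random((4, 8))

    # Mask the top-3 attributed features of every sample in one vectorised call;
    # indices has shape N x I, as the batched perturbation function expects.
    top_indices = np.argsort(-a_batch, axis=1)[:, :3]
    x_perturbed = batch_baseline_replacement_by_indices(
        arr=x_batch,
        indices=top_indices,
        perturb_baseline=0.0,
    )

    # Row-wise Pearson correlation between per-sample attribution sums and
    # per-sample prediction deltas gathered over several runs (B x R arrays),
    # mirroring how the batched FaithfulnessCorrelation aggregates its scores.
    att_sums = rng.random((4, 10))
    pred_deltas = rng.random((4, 10))
    scores = correlation_pearson(a=att_sums, b=pred_deltas, batched=True)
    print(x_perturbed.shape, scores.shape)  # (4, 8) and one correlation per sample

The same pattern — flatten inputs to B x F, perturb with an N x I index array, then compare batched score vectors — is what each of the rewritten evaluate_batch methods in this series relies on.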