def mse(a: np.ndarray, b: np.ndarray, batched: bool = False, **kwargs) -> float:
    """
    Calculate Mean Squared Error between two images (or explanations).

    Parameters
    ----------
    a: np.ndarray
        Array to calculate MSE with.
    b: np.ndarray
        Array to calculate MSE with.
    batched: bool
        True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size
        and F is the number of features.
    kwargs: optional
        Keyword arguments.
        normalise_mse: boolean
            If True, the squared error is evaluated through its polynomial expansion
            a^2 - 2ab + b^2 instead of (a - b)^2. Defaults to False.

    Returns
    -------
    float or np.ndarray
        The squared error averaged over the last axis if batched (one value per batch element),
        otherwise averaged over the first axis.
    """
    normalise = kwargs.get("normalise_mse", False)

    # Batched inputs are reduced over the feature axis, unbatched inputs over the leading axis.
    axis = -1 if batched else 0

    if normalise:
        # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2.
        squared_error = (a**2) - (2 * (a * b)) + (b**2)
    else:
        # If no need to normalise, use (a-b)^2 directly.
        squared_error = (a - b) ** 2

    return np.average(squared_error, axis=axis)


def fro_norm(a: np.ndarray) -> float:
    """
    Calculate the Frobenius norm of an array.

    Parameters
    ----------
    a: np.ndarray
        The array to calculate the Frobenius on. If 2D, the array is assumed to be batched.

    Returns
    -------
    float or np.ndarray
        The norm, computed along the last axis (one value per row for 2D input).
    """
    assert a.ndim in (1, 2), "Check that 'fro_norm' receives a 1D or 2D array."
    return np.linalg.norm(a, axis=-1)


def l2_norm(a: np.ndarray) -> float:
    """
    Calculate the L2 norm of an array.

    Parameters
    ----------
    a: np.ndarray
        The array to calculate the L2 on. If 2D, the array is assumed to be batched.

    Returns
    -------
    float or np.ndarray
        The norm, computed along the last axis (one value per row for 2D input).
    """
    # The message now matches the check: both 1D and 2D (batched) inputs are accepted.
    assert a.ndim in (1, 2), "Check that 'l2_norm' receives a 1D or 2D array."
    return np.linalg.norm(a, axis=-1)


def linf_norm(a: np.ndarray) -> float:
    """
    Calculate the L-inf norm of an array.

    Parameters
    ----------
    a: np.ndarray
        The array to calculate the L-inf on. If 2D, the array is assumed to be batched.

    Returns
    -------
    float or np.ndarray
        The norm, computed along the last axis (one value per row for 2D input).
    """
    assert a.ndim in (1, 2), "Check that 'linf_norm' receives a 1D or 2D array."
    return np.linalg.norm(a, ord=np.inf, axis=-1)
def _assert_batched_2d(arr: np.ndarray, indices: np.ndarray) -> None:
    """Assert that both the data array and the indices array are 2-dimensional (batch first)."""
    assert (
        len(arr.shape) == 2
    ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features"
    assert (
        len(indices.shape) == 2
    ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb"


def _add_noise_at_indices(arr: np.ndarray, indices: np.ndarray, noise: np.ndarray) -> np.ndarray:
    """Return a copy of arr where noise has been added only at the per-row positions in indices."""
    # Column vector of row numbers, broadcast against the N x I indices array.
    rows = np.arange(arr.shape[0])[:, None]
    arr_perturbed = copy.copy(arr)
    # Only the selected entries are summed; the rest of the array is left untouched.
    arr_perturbed[rows, indices] = arr[rows, indices] + noise[rows, indices]
    return arr_perturbed


def batch_baseline_replacement_by_indices(
    arr: np.array,
    indices: np.array,
    perturb_baseline: Union[float, int, str, np.array],
    **kwargs,
) -> np.array:
    """
    Replace indices in a batched array by a given baseline.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Indices of the array to perturb. Shape N x I, where N is batch size and I is the number of indices to perturb.
    perturb_baseline: float, int, str, np.ndarray
        The baseline values to replace arr at indices with.
    kwargs: optional
        Keyword arguments, forwarded to get_baseline_value.

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which the selected indices have been replaced by the baseline.
    """
    _assert_batched_2d(arr, indices)

    arr_perturbed = copy.copy(arr)

    # Get the baseline value, one entry per perturbed index (shape N x I).
    baseline_value = get_baseline_value(
        value=perturb_baseline,
        arr=arr,
        return_shape=tuple(indices.shape),
        batched=True,
        **kwargs,
    )

    # Perturb the array: row i receives its baseline at columns indices[i].
    rows = np.arange(arr.shape[0])[:, None]
    arr_perturbed[rows, indices] = baseline_value

    return arr_perturbed


def baseline_replacement_by_mask(
    arr: np.array,
    mask: np.array,
    perturb_baseline: Union[float, int, str, np.array],
    **kwargs,
) -> np.array:
    """
    Replace the masked elements of an array by a given baseline.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape is arbitrary.
    mask: np.ndarray
        Boolean mask of the array to perturb. Shape must be the same as the arr.
    perturb_baseline: float, int, str, np.ndarray
        The baseline values to replace arr with where mask is True.
    kwargs: optional
        Keyword arguments, forwarded to get_baseline_value.

    Returns
    -------
    arr_perturbed: np.ndarray
        A new array holding the baseline where mask is True and the values of arr elsewhere.
    """
    # Assert dimensions
    assert arr.shape == mask.shape, "The shape of arr must be the same as the mask shape"

    # Get the baseline value.
    baseline_value = get_baseline_value(
        value=perturb_baseline,
        arr=arr,
        return_shape=tuple(arr.shape),
        **kwargs,
    )

    # np.where allocates its own output, so no explicit copy of arr is needed.
    return np.where(mask, baseline_value, arr)


def batch_gaussian_noise(
    arr: np.array,
    indices: np.array,
    perturb_mean: float = 0.0,
    perturb_std: float = 0.01,
    **kwargs,
) -> np.array:
    """
    Add gaussian noise to a batched input at indices.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Indices of the array to perturb. Shape N x I, where N is batch size and I is the number of indices to perturb.
    perturb_mean (float):
        The mean for gaussian noise.
    perturb_std (float):
        The standard deviation for gaussian noise.
    kwargs: optional
        Keyword arguments (unused).

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which noise has been added at the selected indices.
    """
    _assert_batched_2d(arr, indices)

    # Sample the noise for the full array so that seeded runs draw the same values as before.
    noise = np.random.normal(loc=perturb_mean, scale=perturb_std, size=arr.shape)

    return _add_noise_at_indices(arr, indices, noise)


def batch_uniform_noise(
    arr: np.array,
    indices: np.array,
    lower_bound: float = 0.02,
    upper_bound: Union[None, float] = None,
    **kwargs,
) -> np.array:
    """
    Add noise to a batched input at indices as sampled uniformly random from [-lower_bound, lower_bound]
    if upper_bound is None, and [lower_bound, upper_bound] otherwise.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Indices of the array to perturb. Shape N x I, where N is batch size and I is the number of indices to perturb.
    lower_bound: float
        The lower bound for uniform sampling.
    upper_bound: float, optional
        The upper bound for uniform sampling.
    kwargs: optional
        Keyword arguments (unused).

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which noise has been added at the selected indices.
    """
    _assert_batched_2d(arr, indices)

    # Sample the noise for the full array so that seeded runs draw the same values as before.
    if upper_bound is None:
        noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=arr.shape)
    else:
        assert (
            upper_bound > lower_bound
        ), f"Parameter 'upper_bound' needs to be larger than 'lower_bound', but {upper_bound} <= {lower_bound}"
        noise = np.random.uniform(low=lower_bound, high=upper_bound, size=arr.shape)

    return _add_noise_at_indices(arr, indices, noise)
def correlation_spearman(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, np.array]:
    """
    Calculate Spearman rank of two images (or explanations).

    Parameters
    ----------
    a: np.ndarray
        The first array to use for similarity scoring.
    b: np.ndarray
        The second array to use for similarity scoring.
    batched: bool
        True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size
        and F is the number of features.
    kwargs: optional
        Keyword arguments.

    Returns
    -------
    Union[float, np.array]
        The similarity score or a batch of similarity scores.
    """
    if not batched:
        return scipy.stats.spearmanr(a, b)[0]

    assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D"
    # Spearman correlation is not calculated row-wise like pearson. Instead it is calculated
    # between each pair of rows from BOTH a and b, giving a (2B x 2B) matrix.
    correlation = scipy.stats.spearmanr(a, b, axis=1)[0]
    # If the batch size is 1, scipy returns a scalar instead of a matrix.
    if not correlation.shape:
        return np.array([correlation])
    # The upper-right quadrant holds a-rows against b-rows; its diagonal pairs a[i] with b[i].
    batch_size = len(a)
    return np.diag(correlation[:batch_size, batch_size:])


def correlation_pearson(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, np.array]:
    """
    Calculate Pearson correlation of two images (or explanations).

    Parameters
    ----------
    a: np.ndarray
        The first array to use for similarity scoring.
    b: np.ndarray
        The second array to use for similarity scoring.
    batched: bool
        True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size
        and F is the number of features.
    kwargs: optional
        Keyword arguments.

    Returns
    -------
    Union[float, np.array]
        The similarity score or a batch of similarity scores.
    """
    if not batched:
        return scipy.stats.pearsonr(a, b)[0]

    assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D"
    # Whether 'axis' is accepted depends on the installed SciPy release, not on the Python
    # version, so try the vectorised call and fall back to a per-row loop if it is rejected.
    try:
        return scipy.stats.pearsonr(a, b, axis=1)[0]
    except TypeError:
        return np.array([scipy.stats.pearsonr(a_i, b_i)[0] for a_i, b_i in zip(a, b)])


def correlation_kendall_tau(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, np.array]:
    """
    Calculate Kendall Tau correlation of two images (or explanations).

    Parameters
    ----------
    a: np.ndarray
        The first array to use for similarity scoring.
    b: np.ndarray
        The second array to use for similarity scoring.
    batched: bool
        True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size
        and F is the number of features.
    kwargs: optional
        Keyword arguments.

    Returns
    -------
    Union[float, np.array]
        The similarity score or a batch of similarity scores.
    """
    if not batched:
        return scipy.stats.kendalltau(a, b)[0]

    assert len(a.shape) == 2 and len(b.shape) == 2, "Batched arrays must be 2D"
    # No support for axis currently, so just iterating over the batch.
    return np.array([scipy.stats.kendalltau(a_i, b_i)[0] for a_i, b_i in zip(a, b)])


def distance_euclidean(a: np.array, b: np.array, **kwargs) -> Union[float, np.array]:
    """
    Calculate Euclidean distance of two images (or explanations).

    Parameters
    ----------
    a: np.ndarray
        The first array to use for similarity scoring.
    b: np.ndarray
        The second array to use for similarity scoring.
    kwargs: optional
        Keyword arguments.

    Returns
    -------
    Union[float, np.array]
        The distance, reduced over the last axis: a single score for 1D inputs or a batch of
        scores for 2D inputs.
    """
    return np.sqrt(((a - b) ** 2).sum(axis=-1))


def distance_manhattan(a: np.array, b: np.array, **kwargs) -> Union[float, np.array]:
    """
    Calculate Manhattan distance of two images (or explanations).

    Parameters
    ----------
    a: np.ndarray
        The first array to use for similarity scoring.
    b: np.ndarray
        The second array to use for similarity scoring.
    kwargs: optional
        Keyword arguments.

    Returns
    -------
    Union[float, np.array]
        The distance, reduced over the last axis: a single score for 1D inputs or a batch of
        scores for 2D inputs.
    """
    return np.abs(a - b).sum(axis=-1)
def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, List[float]]:
    """
    Calculate Structural Similarity Index Measure of two images (or explanations).

    Parameters
    ----------
    a: np.ndarray
        The first array to use for similarity scoring.
    b: np.ndarray
        The second array to use for similarity scoring.
    batched: bool
        Whether the arrays are batched. If True, a and b are iterated in lockstep along their
        first axis and each pair is scored separately.
    kwargs: optional
        Keyword arguments.
        win_size: optional
            Forwarded to skimage.metrics.structural_similarity. Defaults to None.

    Returns
    -------
    Union[float, List[float]]
        The similarity score, returns a list if batched.
    """
    # Look the option up once instead of once per batch element.
    win_size = kwargs.get("win_size", None)

    def _score_pair(aa: np.array, bb: np.array) -> float:
        # The data range spans the joint value range of both inputs; concatenate only once.
        joint = np.concatenate([aa, bb])
        data_range = float(np.abs(np.max(joint) - np.min(joint)))
        return skimage.metrics.structural_similarity(im1=aa, im2=bb, win_size=win_size, data_range=data_range)

    if batched:
        return [_score_pair(aa, bb) for aa, bb in zip(a, b)]
    return _score_pair(a, b)
if self.channel_first: @@ -314,11 +319,7 @@ def get_random_layer_generator( original_parameters = self.state_dict() random_layer_model = deepcopy(self.model) - modules = [ - layer - for layer in random_layer_model.named_modules() - if (hasattr(layer[1], "reset_parameters")) - ] + modules = [layer for layer in random_layer_model.named_modules() if (hasattr(layer[1], "reset_parameters"))] if order == "top_down": modules = modules[::-1] @@ -408,13 +409,9 @@ def add_mean_shift_to_first_layer( for i in range(module[1].out_channels): if self.channel_first: - module[1].bias[i] = torch.nn.Parameter( - 2 * module[1].bias[i] - torch.unique(fw[i])[0] - ) + module[1].bias[i] = torch.nn.Parameter(2 * module[1].bias[i] - torch.unique(fw[i])[0]) else: - module[1].bias[i] = torch.nn.Parameter( - 2 * module[1].bias[i] - torch.unique(fw[..., i])[0] - ) + module[1].bias[i] = torch.nn.Parameter(2 * module[1].bias[i] - torch.unique(fw[..., i])[0]) return new_model @@ -457,9 +454,7 @@ def get_hidden_representations( # E.g., user can provide index -1, in order to get only representations of the last layer. # E.g., for 7 layers in total, this would correspond to positive index 6. - positive_layer_indices = [ - i if i >= 0 else num_layers + i for i in layer_indices - ] + positive_layer_indices = [i if i >= 0 else num_layers + i for i in layer_indices] if layer_names is None: layer_names = [] @@ -472,9 +467,7 @@ def is_layer_of_interest(layer_index: int, layer_name: str): # skip modules defined by subclassing API. 
hidden_layers = list( # type: ignore filter( - lambda layer: not isinstance( - layer[1], (self.model.__class__, torch.nn.Sequential) - ), + lambda layer: not isinstance(layer[1], (self.model.__class__, torch.nn.Sequential)), all_layers, ) ) @@ -509,13 +502,7 @@ def hook(module, module_in, module_out): @property def random_layer_generator_length(self) -> int: - return len( - [ - i - for i in self.model.named_modules() - if (hasattr(i[1], "reset_parameters")) - ] - ) + return len([i for i in self.model.named_modules() if (hasattr(i[1], "reset_parameters"))]) def safe_isinstance(obj: Any, class_path_str: Union[Iterable[str], str]) -> bool: diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index a7a5f388..66162ca3 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -82,9 +82,7 @@ def _get_predict_kwargs(self, **kwargs: Dict[str, ...]) -> Dict[str, ...]: Filter out those, which are supported by Keras API. """ all_kwargs = {**self.model_predict_kwargs, **kwargs} - predict_kwargs = { - k: all_kwargs[k] for k in all_kwargs if k in self._available_predict_kwargs - } + predict_kwargs = {k: all_kwargs[k] for k in all_kwargs if k in self._available_predict_kwargs} return predict_kwargs @property @@ -214,6 +212,11 @@ def shape_input( # Expand first dimension if this is just a single instance. if not batched: x = x.reshape(1, *shape) + shape = (1, *shape) + + # If shape not the same, reshape the input + if shape != x.shape: + x = x.reshape(*shape) # Set channel order according to expected input of model. 
if self.channel_first: @@ -259,11 +262,7 @@ def get_random_layer_generator( original_parameters = self.state_dict() random_layer_model = clone_model(self.model) - layers = [ - _layer - for _layer in random_layer_model.layers - if len(_layer.get_weights()) > 0 - ] + layers = [_layer for _layer in random_layer_model.layers if len(_layer.get_weights()) > 0] if order == "top_down": layers = layers[::-1] @@ -277,9 +276,7 @@ def get_random_layer_generator( yield layer.name, random_layer_model @cachedmethod(operator.attrgetter("cache")) - def _build_hidden_representation_model( - self, layer_names: Tuple, layer_indices: Tuple - ) -> Model: + def _build_hidden_representation_model(self, layer_names: Tuple, layer_indices: Tuple) -> Model: """ Build a keras model, which outputs the internal representation of layers, specified in layer_names or layer_indices, default all. @@ -337,9 +334,7 @@ def add_mean_shift_to_first_layer( new_model.set_weights(original_parameters) module = new_model.layers[0] - tmp_model = Model( - inputs=[new_model.input], outputs=[new_model.layers[0].output] - ) + tmp_model = Model(inputs=[new_model.input], outputs=[new_model.layers[0].output]) delta = np.zeros(shape=shape) delta.fill(input_shift) @@ -395,9 +390,7 @@ def get_hidden_representations( # E.g., user can provide index -1, in order to get only representations of the last layer. # E.g., for 7 layers in total, this would correspond to positive index 6. 
- positive_layer_indices = [ - i if i >= 0 else num_layers + i for i in layer_indices - ] + positive_layer_indices = [i if i >= 0 else num_layers + i for i in layer_indices] if layer_names is None: layer_names = [] @@ -406,9 +399,7 @@ def get_hidden_representations( tuple(layer_names), tuple(positive_layer_indices) ) predict_kwargs = self._get_predict_kwargs(**kwargs) - internal_representation = hidden_representation_model.predict( - x, **predict_kwargs - ) + internal_representation = hidden_representation_model.predict(x, **predict_kwargs) input_batch_size = x.shape[0] # If we requested outputs only of 1 layer, keras will already return np.ndarray. @@ -416,9 +407,7 @@ def get_hidden_representations( if isinstance(internal_representation, np.ndarray): return internal_representation.reshape((input_batch_size, -1)) - internal_representation = [ - i.reshape((input_batch_size, -1)) for i in internal_representation - ] + internal_representation = [i.reshape((input_batch_size, -1)) for i in internal_representation] return np.hstack(internal_representation) @property diff --git a/quantus/helpers/perturbation_utils.py b/quantus/helpers/perturbation_utils.py index 71b3215a..05845872 100644 --- a/quantus/helpers/perturbation_utils.py +++ b/quantus/helpers/perturbation_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import List, TYPE_CHECKING, Callable, Mapping +from typing import List, TYPE_CHECKING, Callable, Mapping, Optional import numpy as np import functools @@ -18,11 +18,8 @@ class PerturbFunc(Protocol): def __call__( self, arr: np.ndarray, - indices: np.ndarray, - indexed_axes: np.ndarray, **kwargs, - ) -> np.ndarray: - ... + ) -> np.ndarray: ... 
def make_perturb_func( diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 50645979..3f072e30 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -13,6 +13,8 @@ import numpy as np from skimage.segmentation import slic, felzenszwalb +import skimage +import math from quantus.helpers import asserts from quantus.helpers.model.model_interface import ModelInterface @@ -44,13 +46,10 @@ def get_superpixel_segments(img: np.ndarray, segmentation_method: str) -> np.nda if img.ndim != 3: raise ValueError( - "Make sure that x is 3 dimensional e.g., (3, 224, 224) to calculate super-pixels." - f" shape: {img.shape}" + "Make sure that x is 3 dimensional e.g., (3, 224, 224) to calculate super-pixels." f" shape: {img.shape}" ) if segmentation_method not in ["slic", "felzenszwalb"]: - raise ValueError( - "'segmentation_method' must be either 'slic' or 'felzenszwalb'." - ) + raise ValueError("'segmentation_method' must be either 'slic' or 'felzenszwalb'.") if segmentation_method == "slic": return slic(img, start_label=0) @@ -65,6 +64,7 @@ def get_baseline_value( arr: np.ndarray, return_shape: Tuple, patch: Optional[np.ndarray] = None, + batched: bool = False, **kwargs, ) -> np.array: """ @@ -83,6 +83,8 @@ def get_baseline_value( patch: np.ndarray, optional CxWxH patch array to calculate baseline values. Necessary for "neighbourhood_mean" and "neighbourhood_random_min_max" methods. + batched: bool, + If True, the arr and patch are assumed to be of shape B x . The aggregations are done over the individual batch elements. kwargs: optional Keyword arguments. 
@@ -108,7 +110,7 @@ def get_baseline_value( ) ) elif isinstance(value, str): - fill_dict = get_baseline_dict(arr, patch, **kwargs) + fill_dict = get_baseline_dict(arr, patch, batched, **kwargs) if value.lower() == "random": raise ValueError( "'random' as a choice for 'perturb_baseline' is deprecated and has been removed from " @@ -117,17 +119,18 @@ def get_baseline_value( "which will replicate the results of 'random').\n" ) if value.lower() not in fill_dict: - raise ValueError( - f"Ensure that 'value'(string) is in {list(fill_dict.keys())}" - ) - return np.full(return_shape, fill_dict[value.lower()]) + raise ValueError(f"Ensure that 'value'(string) is in {list(fill_dict.keys())}") + fill_value = fill_dict[value.lower()] + # Expand the second dimension if batched to enable broadcasting + if batched: + fill_value = fill_value[:, None] + return np.full(return_shape, fill_value) + else: raise ValueError("Specify 'value' as a np.array, string, integer or float.") -def get_baseline_dict( - arr: np.ndarray, patch: Optional[np.ndarray] = None, **kwargs -) -> dict: +def get_baseline_dict(arr: np.ndarray, patch: Optional[np.ndarray] = None, batched: bool = False, **kwargs) -> dict: """ Make a dictionary of baseline approaches depending on the input x (or patch of input). @@ -138,6 +141,8 @@ def get_baseline_dict( patch: np.ndarray, optional CxWxH patch array to calculate baseline values, necessary for "neighbourhood_mean" and "neighbourhood_random_min_max" methods. + batched: bool + If True, the arr and patch are assumed to be of shape B x . The aggregations are done over the individual batch elements. kwargs: optional Keyword arguments.. @@ -146,20 +151,29 @@ def get_baseline_dict( fill_dict: dict Maps all available baseline methods to baseline values. 
""" + + # Aggregate over elements of batch + aggregation_axes = ( + tuple(range(1 if batched else 0, len(arr.shape))) + if not patch + else tuple(range(1 if batched else 0, len(patch.shape))) + ) fill_dict = { - "mean": float(arr.mean()), + "mean": arr.mean(axis=aggregation_axes), "uniform": np.random.uniform( low=kwargs.get("uniform_low", 0.0), high=kwargs.get("uniform_high", 1.0), - size=kwargs["return_shape"], + # Return only (batch_size, ) if batched to align with other fill_dict values + size=kwargs["return_shape"][0] if batched else kwargs["return_shape"], ), - "black": float(arr.min()), - "white": float(arr.max()), + "black": arr.min(axis=aggregation_axes), + "white": arr.max(axis=aggregation_axes), } if patch is not None: - fill_dict["neighbourhood_mean"] = (float(patch.mean()),) - fill_dict["neighbourhood_random_min_max"] = float( - np.random.uniform(low=patch.min(), high=patch.max()) + fill_dict["neighbourhood_mean"] = patch.mean(axis=aggregation_axes) + fill_dict["neighbourhood_random_min_max"] = np.random.uniform( + low=patch.min(axis=aggregation_axes), + high=patch.max(axis=aggregation_axes), ) return fill_dict @@ -286,9 +300,7 @@ def infer_channel_first(x: np.array) -> bool: raise ValueError(err_msg) else: - raise ValueError( - "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." - ) + raise ValueError("Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels).") def make_channel_first(x: np.array, channel_first: bool = False): @@ -318,9 +330,7 @@ def make_channel_first(x: np.array, channel_first: bool = False): return np.moveaxis(x, -1, -3) else: - raise ValueError( - "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." 
- ) + raise ValueError("Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels).") def make_channel_last(x: np.array, channel_first: bool = True): @@ -349,9 +359,7 @@ def make_channel_last(x: np.array, channel_first: bool = True): elif len(np.shape(x)) == 4: return np.moveaxis(x, -3, -1) else: - raise ValueError( - "Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels)." - ) + raise ValueError("Only batched 1D and 2D multi-channel input dimensions supported (excluding the channels).") def get_wrapped_model( @@ -429,9 +437,7 @@ def blur_at_indices( A version of arr that is blurred at indices. """ - assert kernel.ndim == len( - indexed_axes - ), "kernel should have as many dimensions as indexed_axes has elements." + assert kernel.ndim == len(indexed_axes), "kernel should have as many dimensions as indexed_axes has elements." # Pad array pad_width = [(0, 0) for _ in indexed_axes] @@ -456,9 +462,7 @@ def blur_at_indices( array_indices = np.array(array_indices) # Expand kernel dimensions - expanded_kernel = np.expand_dims( - kernel, tuple([i for i in range(arr.ndim) if i not in indexed_axes]) - ) + expanded_kernel = np.expand_dims(kernel, tuple([i for i in range(arr.ndim) if i not in indexed_axes])) # Iterate over indices, applying expanded kernel x_blur = copy.copy(x) @@ -470,9 +474,7 @@ def blur_at_indices( for ax, idx_ax in enumerate(expanded_idx): s = kernel.shape[ax] idx_ax = np.squeeze(idx_ax) - expanded_idx[ax] = slice( - idx_ax - (s // 2), idx_ax + s // 2 + 1 - (s % 2 == 0) - ) + expanded_idx[ax] = slice(idx_ax - (s // 2), idx_ax + s // 2 + 1 - (s % 2 == 0)) for dim in range(array_indices.ndim - 2): idx = [np.expand_dims(index, axis=index.ndim) for index in idx] @@ -493,9 +495,7 @@ def blur_at_indices( return _unpad_array(x_blur, pad_width, padded_axes=indexed_axes) -def create_patch_slice( - patch_size: Union[int, Sequence[int]], coords: Sequence[int] -) -> Tuple[slice, ...]: +def 
create_patch_slice(patch_size: Union[int, Sequence[int]], coords: Sequence[int]) -> Tuple[slice, ...]: """ Create a patch slice from patch size and coordinates. @@ -526,24 +526,48 @@ def create_patch_slice( elif patch_size.ndim != 1: raise ValueError("patch_size has to be either a scalar or a 1D-sequence") elif len(patch_size) != len(coords): - raise ValueError( - "patch_size sequence length does not match coords length" - f" (len(patch_size) != len(coords))" - ) + raise ValueError("patch_size sequence length does not match coords length" f" (len(patch_size) != len(coords))") # make sure that each element in tuple is integer patch_size = tuple(int(patch_size_dim) for patch_size_dim in patch_size) - patch_slice = [ - slice(coord, coord + patch_size_dim) - for coord, patch_size_dim in zip(coords, patch_size) - ] + patch_slice = [slice(coord, coord + patch_size_dim) for coord, patch_size_dim in zip(coords, patch_size)] return tuple(patch_slice) -def get_nr_patches( - patch_size: Union[int, Sequence[int]], shape: Tuple[int, ...], overlap: bool = False -) -> int: +def get_block_indices(x_batch: np.array, patch_size: int) -> np.array: + """ + Get blocks for a batch of images using a certain patch_size. + + Parameters + ---------- + x_batch: np.array + Batch of images, shape B x C x H x W + patch_size: int + Size of the patch. + + Yields + ------- + np.array + Yields blocks of image of a certain patch_size. 
+ """ + batch_size = x_batch.shape[0] + x_indices = np.stack( + [np.arange(x.size) for x in x_batch], + axis=0, + ).reshape(*x_batch.shape) + blocks = skimage.util.view_as_blocks(x_indices, (*x_batch.shape[:2], patch_size, patch_size)) + blocks_h, blocks_w = blocks.shape[2:4] + block_indices = np.stack( + np.unravel_index(list(range(blocks_h * blocks_w)), shape=(blocks_h, blocks_w)), + axis=1, + ) + + for block_index in block_indices: + yield blocks[0, 0, block_index[0], block_index[1]].reshape(batch_size, -1) + + +def get_nr_patches(patch_size: Union[int, Sequence[int]], shape: Tuple[int, ...], overlap: bool = False) -> int: """ Get number of patches for given shape. @@ -571,10 +595,7 @@ def get_nr_patches( elif patch_size.ndim != 1: raise ValueError("patch_size has to be either a scalar or a 1D-sequence") elif len(patch_size) != len(shape): - raise ValueError( - "patch_size sequence length does not match shape length" - f" (len(patch_size) != len(shape))" - ) + raise ValueError("patch_size sequence length does not match shape length" f" (len(patch_size) != len(shape))") patch_size = tuple(patch_size) return np.prod(shape) // np.prod(patch_size) @@ -606,14 +627,10 @@ def _pad_array( Padded array. 
""" - assert ( - len(padded_axes) <= arr.ndim - ), "Cannot pad more axes than array has dimensions" + assert len(padded_axes) <= arr.ndim, "Cannot pad more axes than array has dimensions" if isinstance(pad_width, Sequence): - assert len(pad_width) == len( - padded_axes - ), "pad_width and padded_axes have different lengths" + assert len(pad_width) == len(padded_axes), "pad_width and padded_axes have different lengths" for p in pad_width: if isinstance(p, tuple): assert len(p) == 2, "Elements in pad_width need to have length 2" @@ -634,7 +651,6 @@ def _pad_array( ) else: pad_width_list.append(pad_width[[p for p in padded_axes].index(ax)]) - arr_pad = np.pad(arr, pad_width_list, mode=mode) return arr_pad @@ -664,14 +680,10 @@ def _unpad_array( """ - assert ( - len(padded_axes) <= arr.ndim - ), "Cannot unpad more axes than array has dimensions" + assert len(padded_axes) <= arr.ndim, "Cannot unpad more axes than array has dimensions" if isinstance(pad_width, Sequence): - assert len(pad_width) == len( - padded_axes - ), "pad_width and padded_axes have different lengths" + assert len(pad_width) == len(padded_axes), "pad_width and padded_axes have different lengths" for p in pad_width: if isinstance(p, tuple): assert len(p) == 2, "Elements in pad_width need to have length 2" @@ -702,6 +714,30 @@ def _unpad_array( return unpadded_arr +def get_padding_size(dim: int, patch_size: int) -> Tuple[int, int]: + """ + Calculate the padding size (optionally) needed for a patch_size. + + Parameters + ---------- + dim: int + Input dimension to pad (for example width or height of the image). + patch_size: int + Size of the patch for this dimension of input. + + Returns + ------- + Tuple[int, int] + A tuple of values passed to the utils._pad_array method for a particular dimension. 
+ """ + modulo = dim % patch_size + if modulo == 0: + return (0, 0) + total_padding = patch_size - modulo + half_padding = total_padding / 2 + return (math.floor(half_padding), math.ceil(half_padding)) + + def expand_attribution_channel(a_batch: np.ndarray, x_batch: np.ndarray): """ Expand additional channel dimension(s) for attributions if needed. @@ -723,9 +759,7 @@ def expand_attribution_channel(a_batch: np.ndarray, x_batch: np.ndarray): f"a_batch and x_batch must have same number of batches ({a_batch.shape[0]} != {x_batch.shape[0]})" ) if a_batch.ndim > x_batch.ndim: - raise ValueError( - f"a must not have greater ndim than x ({a_batch.ndim} > {x_batch.ndim})" - ) + raise ValueError(f"a must not have greater ndim than x ({a_batch.ndim} > {x_batch.ndim})") if a_batch.ndim == x_batch.ndim: return a_batch @@ -764,9 +798,7 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence if a_batch.ndim > x_batch.ndim: raise ValueError( - "Attributions need to have <= dimensions than inputs, but {} > {}".format( - a_batch.ndim, x_batch.ndim - ) + "Attributions need to have <= dimensions than inputs, but {} > {}".format(a_batch.ndim, x_batch.ndim) ) # TODO: We currently assume here that the batch axis is not carried into the perturbation functions. 
@@ -781,17 +813,14 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence return np.array([]) x_subshapes = [ - [x_shape[i] for i in range(start, start + len(a_shape))] - for start in range(0, len(x_shape) - len(a_shape) + 1) + [x_shape[i] for i in range(start, start + len(a_shape))] for start in range(0, len(x_shape) - len(a_shape) + 1) ] if x_subshapes.count(a_shape) < 1: # Check that attribution dimensions are (consecutive) subdimensions of inputs raise ValueError( "Attribution dimensions are not (consecutive) subdimensions of inputs: " - "inputs were of shape {} and attributions of shape {}".format( - x_batch.shape, a_batch.shape - ) + "inputs were of shape {} and attributions of shape {}".format(x_batch.shape, a_batch.shape) ) elif x_subshapes.count(a_shape) > 1: @@ -814,17 +843,13 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence raise ValueError( "Attribution axes could not be inferred for inputs of " - "shape {} and attributions of shape {}".format( - x_batch.shape, a_batch.shape - ) + "shape {} and attributions of shape {}".format(x_batch.shape, a_batch.shape) ) raise ValueError( "Attribution dimensions are not unique subdimensions of inputs: " "inputs were of shape {} and attributions of shape {}." - "Please expand attribution dimensions for a unique solution".format( - x_batch.shape, a_batch.shape - ) + "Please expand attribution dimensions for a unique solution".format(x_batch.shape, a_batch.shape) ) else: # Infer attribution axes. @@ -898,20 +923,10 @@ def expand_indices( # Check if unraveling is needed. 
if np.all([isinstance(idx, int) for idx in expanded_indices]): - expanded_indices = np.unravel_index( - expanded_indices, tuple([arr.shape[i] for i in indexed_axes]) - ) - elif not np.all( - [ - isinstance(idx, np.ndarray) and idx.ndim == len(expanded_indices) - for idx in expanded_indices - ] - ): + expanded_indices = np.unravel_index(expanded_indices, tuple([arr.shape[i] for i in indexed_axes])) + elif not np.all([isinstance(idx, np.ndarray) and idx.ndim == len(expanded_indices) for idx in expanded_indices]): # Meshgrid sliced axes to account for correct slicing. Correct switched first two axes by meshgrid - expanded_indices = [ - np.swapaxes(idx, 0, 1) if idx.ndim > 1 else idx - for idx in np.meshgrid(*expanded_indices) - ] + expanded_indices = [np.swapaxes(idx, 0, 1) if idx.ndim > 1 else idx for idx in np.meshgrid(*expanded_indices)] # Handle case of 1D indices. if np.all([isinstance(idx, int) for idx in expanded_indices]): @@ -966,9 +981,7 @@ def get_leftover_shape(arr: np.array, axes: Sequence[int]) -> Tuple: return leftover_shape -def offset_coordinates( - indices: Union[list, Sequence[int], Tuple[Any]], offset: tuple, img_shape: tuple -): +def offset_coordinates(indices: Union[list, Sequence[int], Tuple[Any]], offset: tuple, img_shape: tuple): """ Checks if offset coordinates are within the image frame. Adapted from: https://github.com/tleemann/road_evaluation. @@ -999,7 +1012,7 @@ def offset_coordinates( return off_coords[valid], valid -def calculate_auc(values: np.array, dx: int = 1): +def calculate_auc(values: np.array, dx: int = 1, batched: bool = False): """ Calculate area under the curve using the composite trapezoidal rule. @@ -1015,7 +1028,8 @@ def calculate_auc(values: np.array, dx: int = 1): np.ndarray Definite integral of values. 
""" - return np.trapz(np.array(values), dx=dx) + axis = 1 if batched else -1 + return np.trapz(np.array(values), dx=dx, axis=axis) T = TypeVar("T") diff --git a/quantus/metrics/faithfulness/faithfulness_correlation.py b/quantus/metrics/faithfulness/faithfulness_correlation.py index 6e8e405d..fe38df61 100644 --- a/quantus/metrics/faithfulness/faithfulness_correlation.py +++ b/quantus/metrics/faithfulness/faithfulness_correlation.py @@ -10,7 +10,7 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -139,7 +139,7 @@ def __init__( # Save metric-specific attributes. if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices if similarity_func is None: similarity_func = correlation_pearson @@ -147,9 +147,7 @@ def __init__( self.similarity_func = similarity_func self.nr_runs = nr_runs self.subset_size = subset_size - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -276,65 +274,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. 
- a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Flatten the attributions. - a = a.flatten() - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - pred_deltas = [] - att_sums = [] - - # For each test data point, execute a couple of runs. - for i_ix in range(self.nr_runs): - # Randomly mask by subset size. - a_ix = np.random.choice(a.shape[0], self.subset_size, replace=False) - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(float(y_pred - y_pred_perturb)) - - # Sum attributions of the random subset. - att_sums.append(np.sum(a[a_ix])) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - - return similarity - def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: """ Implementation of custom_preprocess_batch. @@ -353,9 +292,7 @@ def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: returning a custom preprocess batch (custom_preprocess_batch). """ # Asserts. - asserts.assert_value_smaller_than_input_size( - x=x_batch, value=self.subset_size, value_name="subset_size" - ) + asserts.assert_value_smaller_than_input_size(x=x_batch, value=self.subset_size, value_name="subset_size") def evaluate_batch( self, @@ -387,7 +324,50 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. 
Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Predict on input. + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + pred_deltas = [] + att_sums = [] + + x_batch_shape = x_batch.shape + # For each test data point, execute a couple of runs. + for i_ix in range(self.nr_runs): + # Randomly mask by subset size. + a_ix = np.stack( + [np.random.choice(n_features, self.subset_size, replace=False) for _ in range(batch_size)], + axis=0, + ) + x_perturbed = self.perturb_func( + arr=x_batch.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + pred_deltas.append(y_pred - y_pred_perturb) + + # Sum attributions of the random subset. 
+ att_sums.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=-1)) + pred_deltas = np.stack(pred_deltas, axis=1) + att_sums = np.stack(att_sums, axis=1) + + similarity: np.array = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) + + return similarity.tolist() diff --git a/quantus/metrics/faithfulness/faithfulness_estimate.py b/quantus/metrics/faithfulness/faithfulness_estimate.py index dad5fdaa..4cdf7aa2 100644 --- a/quantus/metrics/faithfulness/faithfulness_estimate.py +++ b/quantus/metrics/faithfulness/faithfulness_estimate.py @@ -9,8 +9,9 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -128,21 +129,16 @@ def __init__( if similarity_func is None: similarity_func = correlation_pearson if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices self.similarity_func = similarity_func self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline' and similarity function " - "'similarity_func'" - ), + sensitive_params=("baseline value 'perturb_baseline' and similarity function " "'similarity_func'"), citation=( "Alvarez-Melis, David, and Tommi S. Jaakkola. 'Towards robust interpretability" " with self-explaining neural networks.' 
arXiv preprint arXiv:1806.07538 (2018)" @@ -259,70 +255,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - # Flatten the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - pred_deltas = [None for _ in range(n_perturbations)] - att_sums = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas[i_ix] = float(y_pred - y_pred_perturb) - - # Sum attributions. - att_sums[i_ix] = np.sum(a[a_ix]) - - similarity = self.similarity_func(a=att_sums, b=pred_deltas) - return similarity - def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None: """ Implementation of custom_preprocess_batch. 
@@ -376,7 +308,52 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a_batch, axis=1) + + # Predict on input. + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + n_perturbations = math.ceil(n_features / self.features_in_step) + pred_deltas = [] + att_sums = [] + x_batch_shape = x_batch.shape + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. + a_ix = a_indices[ + :, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, + ] + x_perturbed = self.perturb_func( + arr=x_batch.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + pred_deltas.append(y_pred - y_pred_perturb) + + # Sum attributions of the random subset. 
+ att_sums.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=-1)) + pred_deltas = np.stack(pred_deltas, axis=1) + att_sums = np.stack(att_sums, axis=1) + + similarity: np.array = self.similarity_func(a=att_sums, b=pred_deltas, batched=True) + + return similarity.tolist() diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index 9a3b0970..b5c07997 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -11,7 +11,9 @@ import numpy as np from quantus.functions.loss_func import mse -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, +) from quantus.helpers import utils, warn from quantus.helpers.enums import ( DataType, @@ -148,16 +150,14 @@ def __init__( self.loss_func = loss_func if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices if perturb_patch_sizes is None: perturb_patch_sizes = [4] self.perturb_patch_sizes = perturb_patch_sizes self.n_perturb_samples = n_perturb_samples self.nr_channels = None - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -283,100 +283,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInterface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. 
- y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - results = [] - - for _ in range(self.n_perturb_samples): - sub_results = [] - - for patch_size in self.perturb_patch_sizes: - pred_deltas = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - a_sums = np.zeros( - (int(a.shape[1] / patch_size), int(a.shape[2] / patch_size)) - ) - x_perturbed = x.copy() - pad_width = patch_size - 1 - - for i_x, top_left_x in enumerate(range(0, x.shape[1], patch_size)): - for i_y, top_left_y in enumerate(range(0, x.shape[2], patch_size)): - # Perturb input patch-wise. - x_perturbed_pad = utils._pad_array( - x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes - ) - patch_slice = utils.create_patch_slice( - patch_size=patch_size, - coords=[top_left_x, top_left_y], - ) - - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - ) - - # Remove padding. - x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - - # Predict on perturbed input x_perturbed. 
- x_input = model.shape_input( - x_perturbed, x.shape, channel_first=True - ) - warn.warn_perturbation_caused_no_change( - x=x, x_perturbed=x_input - ) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - x_diff = x - x_perturbed - a_diff = np.dot( - np.repeat(a, repeats=self.nr_channels, axis=0), x_diff - ) - - pred_deltas[i_x][i_y] = y_pred - y_pred_perturb - a_sums[i_x][i_y] = np.sum(a_diff) - - assert callable(self.loss_func) - sub_results.append( - self.loss_func(a=pred_deltas.flatten(), b=a_sums.flatten()) - ) - - results.append(np.mean(sub_results)) - - return np.mean(results) - def custom_preprocess( self, x_batch: np.ndarray, @@ -429,8 +335,70 @@ def evaluate_batch( scores_batch: The evaluation results. """ + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Predict on input. 
+ x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + results = [] + for _ in range(self.n_perturb_samples): + sub_results = [] + for patch_size in self.perturb_patch_sizes: + pred_deltas = [] + a_sums = [] + x_perturbed = x_batch.copy() + x_perturbed_h, x_perturbed_w = x_perturbed.shape[-2:] + # Pad the input + padding_h, padding_w = utils.get_padding_size(x_perturbed_h, patch_size), utils.get_padding_size( + x_perturbed_w, patch_size + ) + x_perturbed_pad = utils._pad_array( + x_perturbed, + ((0, 0), (0, 0), padding_h, padding_w), + mode="edge", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + x_perturbed_pad_shape = x_perturbed_pad.shape + for x_indices in utils.get_block_indices(x_perturbed_pad, patch_size): + # Perturb input by block indices of certain patch size + x_perturbed_pad = self.perturb_func( + arr=x_perturbed_pad.reshape(batch_size, -1), + indices=x_indices, + ) + x_perturbed_pad = x_perturbed_pad.reshape(*x_perturbed_pad_shape) + x_perturbed = x_perturbed_pad[ + :, + :, + padding_h[0] : x_perturbed_pad.shape[2] - padding_h[1], + padding_w[0] : x_perturbed_pad.shape[3] - padding_w[1], + ] + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. 
+ x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + x_diff = x_batch - x_perturbed + a_diff = a_batch * x_diff.reshape(batch_size, -1) + + pred_deltas.append(y_pred - y_pred_perturb) + a_sums.append(np.sum(a_diff, axis=-1)) + + pred_deltas = np.stack(pred_deltas, axis=1) + a_sums = np.stack(a_sums, axis=1) + assert callable(self.loss_func) + sub_results.append(self.loss_func(a=pred_deltas, b=a_sums, batched=True)) + results.append(np.mean(np.stack(sub_results, axis=1), axis=-1)) + results = np.stack(results, axis=1) + return np.mean(results, axis=-1) diff --git a/quantus/metrics/faithfulness/irof.py b/quantus/metrics/faithfulness/irof.py index 2b0650cd..ce6aa63a 100644 --- a/quantus/metrics/faithfulness/irof.py +++ b/quantus/metrics/faithfulness/irof.py @@ -10,7 +10,7 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import baseline_replacement_by_mask from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( DataType, @@ -126,14 +126,12 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = baseline_replacement_by_mask # Save metric-specific attributes. self.segmentation_method = segmentation_method self.nr_channels = None - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -259,79 +257,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. 
- - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - Returns - ------- - float - The evaluation results. - """ - # Predict on x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - # Segment image. - segments = utils.get_superpixel_segments( - img=np.moveaxis(x, 0, -1).astype("double"), - segmentation_method=self.segmentation_method, - ) - nr_segments = len(np.unique(segments)) - asserts.assert_nr_segments(nr_segments=nr_segments) - - # Calculate average attribution of each segment. - att_segs = np.zeros(nr_segments) - for i, s in enumerate(range(nr_segments)): - att_segs[i] = np.mean(a[:, segments == s]) - - # Sort segments based on the mean attribution (descending order). - s_indices = np.argsort(-att_segs) - - preds = [] - x_prev_perturbed = x - - for i_ix, s_ix in enumerate(s_indices): - # Perturb input by indices of attributions. - a_ix = np.nonzero((segments == s_ix).flatten())[0] - - x_perturbed = self.perturb_func( - arr=x_prev_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change( - x=x_prev_perturbed, x_perturbed=x_perturbed - ) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - # Normalise the scores to be within range [0, 1]. - preds.append(float(y_pred_perturb / y_pred)) - x_prev_perturbed = x_perturbed - - # Calculate the area over the curve (AOC) score. - aoc = len(preds) - utils.calculate_auc(np.array(preds)) - return aoc - def custom_preprocess( self, x_batch: np.ndarray, @@ -397,7 +322,68 @@ def evaluate_batch( scores_batch: The evaluation results. 
""" - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + + # Predict on input. + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + # Segment image. + segments_batch = [] + s_indices_batch_list = [] + for x, a in zip(x_batch, a_batch): + segments = utils.get_superpixel_segments( + img=np.moveaxis(x, 0, -1).astype("double"), + segmentation_method=self.segmentation_method, + ) + nr_segments = len(np.unique(segments)) + asserts.assert_nr_segments(nr_segments=nr_segments) + segments_batch.append(segments) + + # Calculate average attribution of each segment. + att_segs = np.zeros(nr_segments) + for i, s in enumerate(range(nr_segments)): + att_segs[i] = np.mean(a[:, segments == s]) + + # Sort segments based on the mean attribution (descending order). + s_indices = np.argsort(-att_segs) + s_indices_batch_list.append(s_indices) + segments_batch = np.stack(segments_batch, axis=0) + max_segments_len = max([len(s_indices) for s_indices in s_indices_batch_list]) + mask_preds_batch = np.array( + [[1.0] * len(s_indices) + [0] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch_list] + ) + s_indices_batch = np.array( + [s_indices.tolist() + [-1] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch_list] + ) + + preds = [] + x_perturbed = x_batch.copy() + for s_indices_segment in s_indices_batch.T: + # Perturb input by indices of attributions. 
+ mask = (segments_batch == s_indices_segment[:, None, None])[:, None] + + x_new_perturbed = self.perturb_func( + arr=x_perturbed, + mask=mask, + ) + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_new_perturbed, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_new_perturbed, x_new_perturbed.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + # Normalise the scores to be within range [0, 1]. + preds.append(y_pred_perturb / y_pred) + x_perturbed = x_new_perturbed + preds = np.stack(preds, axis=1) * mask_preds_batch + # Calculate the area over the curve (AOC) score. + aoc = mask_preds_batch.sum(-1) - utils.calculate_auc(preds, batched=True) + return aoc diff --git a/quantus/metrics/faithfulness/monotonicity.py b/quantus/metrics/faithfulness/monotonicity.py index 0f3ca879..d7eaf440 100644 --- a/quantus/metrics/faithfulness/monotonicity.py +++ b/quantus/metrics/faithfulness/monotonicity.py @@ -9,8 +9,9 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.helpers import asserts, utils, warn from quantus.helpers.enums import ( DataType, @@ -127,7 +128,7 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes. 
self.features_in_step = features_in_step @@ -255,67 +256,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Prepare shapes. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - - # Copy the input x but fill with baseline values. - baseline_value = utils.get_baseline_value( - value=self.perturb_func.keywords["perturb_baseline"], # type: ignore - arr=x, - return_shape=x.shape, # TODO. Double-check this over using = (1,). - ) - x_baseline = np.full(x.shape, baseline_value) - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_baseline = self.perturb_func( - arr=x_baseline, - indices=a_ix, - indexed_axes=self.a_axes, - ) - - # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). - x_input = model.shape_input(x_baseline, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - return np.all(np.diff(preds) >= 0) - def custom_preprocess( self, x_batch: np.ndarray, @@ -371,7 +311,54 @@ def evaluate_batch( scores_batch: The evaluation results. 
""" - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (ascending). + a_indices = np.argsort(a_batch, axis=1) + + n_perturbations = math.ceil(n_features / self.features_in_step) + preds = [] + + # Copy the input x but fill with baseline values. + baseline_value = utils.get_baseline_value( + value=self.perturb_func.keywords["perturb_baseline"], # type: ignore + arr=x_batch.reshape(batch_size, -1), + return_shape=( + batch_size, + n_features, + ), + batched=True, + ) + x_baseline = np.full((batch_size, n_features), baseline_value).reshape( + *x_batch.shape + ) + + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. + a_ix = a_indices[ + :, + perturbation_step_index + * self.features_in_step : (perturbation_step_index + 1) + * self.features_in_step, + ] + x_baseline = self.perturb_func( + arr=x_baseline.reshape(batch_size, -1), + indices=a_ix, + ) + x_baseline = x_baseline.reshape(*x_batch.shape) + + # Predict on perturbed input x (that was initially filled with a constant 'perturb_baseline' value). 
+ x_input = model.shape_input( + x_baseline, x_batch.shape, channel_first=True, batched=True + ) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + preds.append(y_pred_perturb) + preds = np.stack(preds, axis=1) + + return np.all(np.diff(preds) >= 0, axis=1).tolist() diff --git a/quantus/metrics/faithfulness/monotonicity_correlation.py b/quantus/metrics/faithfulness/monotonicity_correlation.py index 7efddf17..ee011e35 100644 --- a/quantus/metrics/faithfulness/monotonicity_correlation.py +++ b/quantus/metrics/faithfulness/monotonicity_correlation.py @@ -10,8 +10,9 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.functions.similarity_func import correlation_spearman from quantus.helpers import asserts, warn from quantus.helpers.enums import ( @@ -141,16 +142,14 @@ def __init__( similarity_func = correlation_spearman if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices self.similarity_func = similarity_func self.eps = eps self.nr_samples = nr_samples self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -276,77 +275,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. 
- x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - float - The evaluation results. - """ - # Predict on input x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - inv_pred = 1.0 if np.abs(y_pred) < self.eps else 1.0 / np.abs(y_pred) - inv_pred = inv_pred ** 2 - - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (ascending). - a_indices = np.argsort(a) - - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - atts = [None for _ in range(n_perturbations)] - vars = [None for _ in range(n_perturbations)] - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - - y_pred_perturbs = [] - - for s_ix in range(self.nr_samples): - x_perturbed = self.perturb_func( - arr=x, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - y_pred_perturbs.append(y_pred_perturb) - - vars[i_ix] = float( - np.mean((np.array(y_pred_perturbs) - np.array(y_pred)) ** 2) * inv_pred - ) - atts[i_ix] = float(sum(a[a_ix])) - - return self.similarity_func(a=atts, b=vars) - def custom_preprocess( self, x_batch: np.ndarray, @@ -403,13 +331,59 @@ def evaluate_batch( The evaluation results. """ - # Evaluate explanations. - return [ - self.evaluate_instance( - model=model, - x=x, - y=y, - a=a, - ) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Predict on input x. 
+ x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + batch_size = x_batch.shape[0] + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + # Scale by 1 / y_pred^2, leaving samples with |y_pred| < eps at 1; both sides use the same mask so shapes match. + inv_pred = np.ones(batch_size) + inv_pred[np.abs(y_pred) >= self.eps] = 1.0 / (y_pred[np.abs(y_pred) >= self.eps] ** 2) + + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Reshape attributions. + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (ascending). + a_indices = np.argsort(a_batch, axis=1) + + n_perturbations = math.ceil(n_features / self.features_in_step) + atts = [] + vars = [] + x_batch_shape = x_batch.shape + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. + a_ix = a_indices[ + :, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, + ] + + y_pred_perturbs = [] + for _ in range(self.nr_samples): + x_perturbed = self.perturb_func( + arr=x_batch.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x.
+ x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + y_pred_perturbs.append(y_pred_perturb) + y_pred_perturbs = np.stack(y_pred_perturbs, axis=1) + + vars.append(np.mean((y_pred_perturbs - y_pred[:, None]) ** 2, axis=1) * inv_pred) + atts.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=1)) + vars = np.stack(vars, axis=1) + atts = np.stack(atts, axis=1) + + similarities: np.array = self.similarity_func(a=atts, b=vars, batched=True) + return similarities.tolist() diff --git a/quantus/metrics/faithfulness/pixel_flipping.py b/quantus/metrics/faithfulness/pixel_flipping.py index a081b6ba..20330f11 100644 --- a/quantus/metrics/faithfulness/pixel_flipping.py +++ b/quantus/metrics/faithfulness/pixel_flipping.py @@ -9,8 +9,9 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np +import math -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import batch_baseline_replacement_by_indices from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( DataType, @@ -128,14 +129,12 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes. self.features_in_step = features_in_step self.return_auc_per_sample = return_auc_per_sample - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. 
if not self.disable_warnings: @@ -255,66 +254,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> Union[float, List[float]]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - list - The evaluation results. - """ - - # Reshape attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Prepare lists. - n_perturbations = len(range(0, len(a_indices), self.features_in_step)) - preds = [None for _ in range(n_perturbations)] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x. 
- x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - preds[i_ix] = y_pred_perturb - - if self.return_auc_per_sample: - return float(utils.calculate_auc(preds)) - - return preds - def custom_preprocess( self, x_batch: np.ndarray, @@ -343,9 +282,7 @@ def custom_preprocess( @property def get_auc_score(self): """Calculate the area under the curve (AUC) score for several test samples.""" - return np.mean( - [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] - ) + return np.mean([utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores]) def evaluate_batch( self, @@ -377,7 +314,45 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a_batch, axis=1) + + # Prepare lists. + n_perturbations = math.ceil(n_features / self.features_in_step) + preds = [] + x_perturbed = x_batch.copy() + x_batch_shape = x_batch.shape + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. 
+ a_ix = a_indices[ + :, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, + ] + x_perturbed = self.perturb_func( + arr=x_perturbed.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + preds.append(y_pred_perturb) + + if self.return_auc_per_sample: + return utils.calculate_auc(np.stack(preds, axis=1), batched=True).tolist() + + return np.stack(preds, axis=1).tolist() diff --git a/quantus/metrics/faithfulness/region_perturbation.py b/quantus/metrics/faithfulness/region_perturbation.py index 7a7370e2..faec84a8 100644 --- a/quantus/metrics/faithfulness/region_perturbation.py +++ b/quantus/metrics/faithfulness/region_perturbation.py @@ -12,7 +12,9 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, +) from quantus.helpers import asserts, plotting, utils, warn from quantus.helpers.enums import ( DataType, @@ -140,15 +142,13 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes. 
self.patch_size = patch_size self.order = order.lower() self.regions_evaluation = regions_evaluation - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. asserts.assert_attributions_order(order=self.order) @@ -276,178 +276,137 @@ def __call__( **kwargs, ) - def evaluate_instance( + @property + def get_auc_score(self): + """Calculate the area under the curve (AUC) score for several test samples.""" + return np.mean([utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores]) + + def evaluate_batch( self, model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> List[float]: + x_batch: np.ndarray, + y_batch: np.ndarray, + a_batch: np.ndarray, + **kwargs, + ) -> List[List[float]]: """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. + This method performs XAI evaluation on a single batch of explanations. + For more information on the specific logic, we refer the metric’s initialisation docstring. Parameters ---------- model: ModelInterface A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - s: np.ndarray - The segmentation to be evaluated on an instance-basis. + x_batch: np.ndarray + The input to be evaluated on a batch-basis. + y_batch: np.ndarray + The output to be evaluated on a batch-basis. + a_batch: np.ndarray + The explanation to be evaluated on a batch-basis. + kwargs: + Unused. Returns ------- - : list + scores_batch: The evaluation results. """ + # Prepare shapes. 
Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + x_batch_shape = x_batch.shape + if len(x_batch.shape) == 3: + x_batch = x_batch[:, :, None] + a_batch = a_batch[:, :, None] + + batch_size = a_batch.shape[0] # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) + x_input = model.shape_input(x_batch, x_batch_shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] - patches = [] - x_perturbed = x.copy() + x_perturbed = x_batch.copy() # Pad input and attributions. This is needed to allow for any patch_size. - pad_width = self.patch_size - 1 - x_pad = utils._pad_array(x, pad_width, mode="constant", padded_axes=self.a_axes) - a_pad = utils._pad_array(a, pad_width, mode="constant", padded_axes=self.a_axes) + x_perturbed_h, x_perturbed_w = x_perturbed.shape[-2:] + padding_h, padding_w = utils.get_padding_size(x_perturbed_h, self.patch_size), utils.get_padding_size( + x_perturbed_w, self.patch_size + ) + padding = ((0, 0), (0, 0), padding_h, padding_w) + x_pad = utils._pad_array( + x_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + a_pad = utils._pad_array( + a_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) # Create patches across whole input shape and aggregate attributions. - att_sums = [] - axis_iterators = [ - range(pad_width, x_pad.shape[axis] - pad_width) for axis in self.a_axes - ] - for top_left_coords in itertools.product(*axis_iterators): + att_sums_list = [] + patches_list = [] + for block_indices in utils.get_block_indices(x_pad, self.patch_size): # Create slice for patch. 
- patch_slice = utils.create_patch_slice( - patch_size=self.patch_size, - coords=top_left_coords, - ) + a_sum = a_pad.reshape(batch_size, -1)[np.arange(batch_size)[:, None], block_indices].sum(axis=-1) # Sum attributions for patch. - att_sums.append( - a_pad[utils.expand_indices(a_pad, patch_slice, self.a_axes)].sum() - ) - patches.append(patch_slice) + att_sums_list.append(a_sum) + patches_list.append(block_indices) + att_sums = np.stack(att_sums_list, -1) + patches = np.stack(patches_list, 1) if self.order == "random": # Order attributions randomly. - order = np.arange(len(patches)) - np.random.shuffle(order) + order = np.array([np.random.permutation(patches.shape[1]) for _ in range(batch_size)]) elif self.order == "morf": # Order attributions according to the most relevant first. - order = np.argsort(att_sums)[::-1] + order = np.argsort(-att_sums, -1) elif self.order == "lerf": # Order attributions according to the least relevant first. - order = np.argsort(att_sums) + order = np.argsort(att_sums, -1) else: - raise ValueError( - "Chosen order must be in ['random', 'morf', 'lerf'] but is: {self.order}." - ) + raise ValueError(f"Chosen order must be in ['random', 'morf', 'lerf'] but is: {self.order}.") # Create ordered list of patches.
- ordered_patches = [patches[p] for p in order] - - # Remove overlapping patches - blocked_mask = np.zeros(x_pad.shape, dtype=bool) - ordered_patches_no_overlap = [] - for patch_slice in ordered_patches: - patch_mask = np.zeros(x_pad.shape, dtype=bool) - patch_mask[ - utils.expand_indices(patch_mask, patch_slice, self.a_axes) - ] = True - # patch_mask_exp = utils.expand_indices(patch_mask, patch_slice, self.a_axes) - # patch_mask[patch_mask_exp] = True - intersected = blocked_mask & patch_mask - - if not intersected.any(): - ordered_patches_no_overlap.append(patch_slice) - blocked_mask = blocked_mask | patch_mask - - if len(ordered_patches_no_overlap) >= self.regions_evaluation: - break - - # Warn - warn.warn_iterations_exceed_patch_number( - self.regions_evaluation, len(ordered_patches_no_overlap) - ) + ordered_patches = patches[np.arange(batch_size)[:, None], order].transpose(1, 0, 2) # Increasingly perturb the input and store the decrease in function value. - results = [None for _ in range(len(ordered_patches_no_overlap))] - for patch_id, patch_slice in enumerate(ordered_patches_no_overlap): - # Pad x_perturbed. The mode should probably depend on the used perturb_func? - x_perturbed_pad = utils._pad_array( - x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes - ) - + results = [] + x_perturbed_pad = utils._pad_array( + x_perturbed, + padding, + mode="edge", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + x_perturbed_pad_shape = x_perturbed_pad.shape + for patch_slice in ordered_patches[: self.regions_evaluation]: # Perturb. - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - ) + x_perturbed_pad = self.perturb_func(arr=x_perturbed_pad.reshape(batch_size, -1), indices=patch_slice) # Remove padding. 
- x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - results[patch_id] = y_pred - y_pred_perturb - + x_perturbed_pad = x_perturbed_pad.reshape(*x_perturbed_pad_shape) + x_perturbed = x_perturbed_pad[ + :, + :, + padding_h[0] : x_perturbed_pad.shape[2] - padding_h[1], + padding_w[0] : x_perturbed_pad.shape[3] - padding_w[1], + ] + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch_shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + results.append(y_pred - y_pred_perturb) + results = np.stack(results, 1) return results - - @property - def get_auc_score(self): - """Calculate the area under the curve (AUC) score for several test samples.""" - return np.mean( - [utils.calculate_auc(np.array(curve)) for curve in self.evaluation_scores] - ) - - def evaluate_batch( - self, - model: ModelInterface, - x_batch: np.ndarray, - y_batch: np.ndarray, - a_batch: np.ndarray, - **kwargs, - ) -> List[List[float]]: - """ - This method performs XAI evaluation on a single batch of explanations. - For more information on the specific logic, we refer the metric’s initialisation docstring. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x_batch: np.ndarray - The input to be evaluated on a batch-basis. - y_batch: np.ndarray - The output to be evaluated on a batch-basis. 
- a_batch: np.ndarray - The explanation to be evaluated on a batch-basis. - kwargs: - Unused. - - Returns - ------- - scores_batch: - The evaluation results. - """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] diff --git a/quantus/metrics/faithfulness/road.py b/quantus/metrics/faithfulness/road.py index 5e25f0dd..bafaf1e2 100644 --- a/quantus/metrics/faithfulness/road.py +++ b/quantus/metrics/faithfulness/road.py @@ -132,9 +132,7 @@ def __init__( self.percentages = percentages self.a_size = None - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, noise=noise - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, noise=noise) # Asserts and warnings. if not self.disable_warnings: @@ -256,57 +254,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> List[float]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - list: - The evaluation results. - """ - # Order indices. - ordered_indices = np.argsort(a, axis=None)[::-1] - - results_instance = np.array([None for _ in self.percentages]) - - for p_ix, p in enumerate(self.percentages): - top_k_indices = ordered_indices[: int(self.a_size * p / 100)] - - x_perturbed = self.perturb_func( # type: ignore - arr=x, - indices=top_k_indices, - ) - - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. 
- x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - class_pred_perturb = np.argmax(model.predict(x_input)) - - # Write a boolean into the percentage results. - results_instance[p_ix] = int(y == class_pred_perturb) - - # Return list of booleans for each percentage. - return results_instance - def custom_batch_preprocess(self, a_batch: np.ndarray, **kwargs) -> None: """ROAD requires `a_size` property to be set to `image_height` * `image_width` of an explanation.""" if self.a_size is None: @@ -330,9 +277,10 @@ def custom_postprocess( """ # Calculate accuracy for every number of most important pixels removed. + percentage_scores = self.evaluation_scores[0].mean(axis=0) self.evaluation_scores = { - percentage: np.mean(np.array(self.evaluation_scores)[:, p_ix]) - for p_ix, percentage in enumerate(self.percentages) + percentage: percentage_score + for p_ix, (percentage, percentage_score) in enumerate(zip(self.percentages, percentage_scores)) } def evaluate_batch( @@ -365,7 +313,37 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). 
+ ordered_indices = np.argsort(-a_batch, axis=1) + results = [] + for p in self.percentages: + top_k_indices = ordered_indices[:, : int(self.a_size * p / 100)] + + x_perturbed = [] + for x_element, top_k_index in zip(x_batch, top_k_indices): + x_perturbed_element = self.perturb_func( # type: ignore + arr=x_element, + indices=top_k_index, + ) + x_perturbed.append(x_perturbed_element) + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + x_perturbed = np.stack(x_perturbed, axis=0) + + # Predict the class for every perturbed input in the batch. + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + class_pred_perturb = np.argmax(model.predict(x_input), axis=-1) + + # Write a boolean into the percentage results. + results.append(y_batch == class_pred_perturb) + results = np.stack(results, axis=1).astype(int) + # Return a one-element list wrapping the 0/1 match array (percentages stacked along axis 1). + return [results] diff --git a/quantus/metrics/faithfulness/selectivity.py b/quantus/metrics/faithfulness/selectivity.py index 2058d05e..d9197cbd 100644 --- a/quantus/metrics/faithfulness/selectivity.py +++ b/quantus/metrics/faithfulness/selectivity.py @@ -12,7 +12,7 @@ import numpy as np -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import baseline_replacement_by_indices, batch_baseline_replacement_by_indices from quantus.helpers import plotting, utils, warn from quantus.helpers.enums import ( DataType, @@ -134,22 +134,17 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes.
self.patch_size = patch_size - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: warn.warn_parameterisation( metric_name=self.__class__.__name__, - sensitive_params=( - "baseline value 'perturb_baseline' and the patch size for masking" - " 'patch_size'" - ), + sensitive_params=("baseline value 'perturb_baseline' and the patch size for masking" " 'patch_size'"), data_domain_applicability=( f"Also, the current implementation only works for 3-dimensional (image) data." ), @@ -266,117 +261,10 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> List[float]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - : list - The evaluation results. - """ - - # Predict on input. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - patches = [] - x_perturbed = x.copy() - - # Pad input and attributions. This is needed to allow for any patch_size. - pad_width = self.patch_size - 1 - x_pad = utils._pad_array(x, pad_width, mode="constant", padded_axes=self.a_axes) - a_pad = utils._pad_array(a, pad_width, mode="constant", padded_axes=self.a_axes) - - # Get patch indices of sorted attributions (descending). 
- att_sums = [] - axis_iterators = [ - range(pad_width, x_pad.shape[axis] - pad_width) for axis in self.a_axes - ] - for top_left_coords in itertools.product(*axis_iterators): - # Create slice for patch. - patch_slice = utils.create_patch_slice( - patch_size=self.patch_size, - coords=top_left_coords, - ) - - # Sum attributions for patch. - att_sums.append( - a_pad[utils.expand_indices(a_pad, patch_slice, self.a_axes)].sum() - ) - patches.append(patch_slice) - - # Create ordered list of patches. - ordered_patches = [patches[p] for p in np.argsort(att_sums)[::-1]] - - # Remove overlapping patches. - blocked_mask = np.zeros(x_pad.shape, dtype=bool) - ordered_patches_no_overlap = [] - for patch_slice in ordered_patches: - patch_mask = np.zeros(x_pad.shape, dtype=bool) - patch_mask_exp = utils.expand_indices(patch_mask, patch_slice, self.a_axes) - patch_mask[patch_mask_exp] = True - intersected = blocked_mask & patch_mask - - if not intersected.any(): - ordered_patches_no_overlap.append(patch_slice) - blocked_mask = blocked_mask | patch_mask - - # Increasingly perturb the input and store the decrease in function value. - results = np.array([None for _ in range(len(ordered_patches_no_overlap))]) - for patch_id, patch_slice in enumerate(ordered_patches_no_overlap): - # Pad x_perturbed. The mode should depend on the used perturb_func. - x_perturbed_pad = utils._pad_array( - x_perturbed, pad_width, mode="edge", padded_axes=self.a_axes - ) - - # Perturb the input. - x_perturbed_pad = self.perturb_func( - arr=x_perturbed_pad, - indices=patch_slice, - indexed_axes=self.a_axes, - ) - - # Remove padding. - x_perturbed = utils._unpad_array( - x_perturbed_pad, pad_width, padded_axes=self.a_axes - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Predict on perturbed input x and store the difference from predicting on unperturbed input. 
- x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - - results[patch_id] = y_pred_perturb - - return results - @property def get_auc_score(self): """Calculate the area under the curve (AUC) score for several test samples.""" - return np.mean( - [ - utils.calculate_auc(np.array(curve)) - for i, curve in enumerate(self.evaluation_scores) - ] - ) + return np.mean([utils.calculate_auc(np.array(curve)) for i, curve in enumerate(self.evaluation_scores)]) def evaluate_batch( self, @@ -408,7 +296,85 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + x_batch_shape = x_batch.shape + if len(x_batch.shape) == 3: + x_batch = x_batch[:, :, None] + a_batch = a_batch[:, :, None] + + batch_size = a_batch.shape[0] + + x_perturbed = x_batch.copy() + + # Pad input and attributions. This is needed to allow for any patch_size. + x_perturbed_h, x_perturbed_w = x_perturbed.shape[-2:] + padding_h, padding_w = utils.get_padding_size(x_perturbed_h, self.patch_size), utils.get_padding_size( + x_perturbed_w, self.patch_size + ) + padding = ((0, 0), (0, 0), padding_h, padding_w) + x_pad = utils._pad_array( + x_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + a_pad = utils._pad_array( + a_batch, + padding, + mode="constant", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + + # Get patch indices of sorted attributions (descending). + att_sums_list = [] + patches_list = [] + for block_indices in utils.get_block_indices(x_pad, self.patch_size): + # Create slice for patch. + a_sum = a_pad.reshape(batch_size, -1)[np.arange(batch_size)[:, None], block_indices].sum(axis=-1) + + # Sum attributions for patch. 
+ att_sums_list.append(a_sum) + patches_list.append(block_indices) + att_sums = np.stack(att_sums_list, -1) + patches = np.stack(patches_list, 1) + + # Create ordered list of patches. + order = np.argsort(-att_sums, -1) + ordered_patches = patches[np.arange(batch_size)[:, None], order].transpose(1, 0, 2) + + # Increasingly perturb the input and store the decrease in function value. + results = [] + x_perturbed_pad = utils._pad_array( + x_perturbed, + padding, + mode="edge", + padded_axes=np.arange(len(x_perturbed.shape)), + ) + x_perturbed_pad_shape = x_perturbed_pad.shape + for patch_slice in ordered_patches: + # Perturb. + x_perturbed_pad = self.perturb_func(arr=x_perturbed_pad.reshape(batch_size, -1), indices=patch_slice) + + # Remove padding. + x_perturbed_pad = x_perturbed_pad.reshape(*x_perturbed_pad_shape) + x_perturbed = x_perturbed_pad[ + :, + :, + padding_h[0] : x_perturbed_pad.shape[2] - padding_h[1], + padding_w[0] : x_perturbed_pad.shape[3] - padding_w[1], + ] + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Predict on perturbed input x. 
+ x_input = model.shape_input(x_perturbed, x_batch_shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + + results.append(y_pred_perturb) + results = np.stack(results, 1, dtype=np.float64) + + return results diff --git a/quantus/metrics/faithfulness/sensitivity_n.py b/quantus/metrics/faithfulness/sensitivity_n.py index 6abafee8..363645d0 100644 --- a/quantus/metrics/faithfulness/sensitivity_n.py +++ b/quantus/metrics/faithfulness/sensitivity_n.py @@ -10,9 +10,12 @@ from typing import Any, Callable, Dict, List, Optional import numpy as np +import math from quantus.functions.normalise_func import normalise_by_max -from quantus.functions.perturb_func import baseline_replacement_by_indices +from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, +) from quantus.functions.similarity_func import correlation_pearson from quantus.helpers import asserts, plotting, warn from quantus.helpers.enums import ( @@ -141,7 +144,7 @@ def __init__( ) if perturb_func is None: - perturb_func = baseline_replacement_by_indices + perturb_func = batch_baseline_replacement_by_indices # Save metric-specific attributes. if similarity_func is None: @@ -149,9 +152,7 @@ def __init__( self.similarity_func = similarity_func self.n_max_percentage = n_max_percentage self.features_in_step = features_in_step - self.perturb_func = make_perturb_func( - perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline - ) + self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline) # Asserts and warnings. if not self.disable_warnings: @@ -275,69 +276,6 @@ def __call__( **kwargs, ) - def evaluate_instance( - self, - model: ModelInterface, - x: np.ndarray, - y: np.ndarray, - a: np.ndarray, - ) -> Dict[str, List[float]]: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. 
- - Parameters - ---------- - model: ModelInterface - A ModelInteface that is subject to explanation. - x: np.ndarray - The input to be evaluated on an instance-basis. - y: np.ndarray - The output to be evaluated on an instance-basis. - a: np.ndarray - The explanation to be evaluated on an instance-basis. - - Returns - ------- - (Dict[str, List[float]]): The evaluation results. - """ - - # Reshape the attributions. - a = a.flatten() - - # Get indices of sorted attributions (descending). - a_indices = np.argsort(-a) - - # Predict on x. - x_input = model.shape_input(x, x.shape, channel_first=True) - y_pred = float(model.predict(x_input)[:, y]) - - att_sums = [] - pred_deltas = [] - x_perturbed = x.copy() - - for i_ix, a_ix in enumerate(a_indices[:: self.features_in_step]): - # Perturb input by indices of attributions. - a_ix = a_indices[ - (self.features_in_step * i_ix) : (self.features_in_step * (i_ix + 1)) - ] - x_perturbed = self.perturb_func( - arr=x_perturbed, - indices=a_ix, - indexed_axes=self.a_axes, - ) - warn.warn_perturbation_caused_no_change(x=x, x_perturbed=x_perturbed) - - # Sum attributions. - att_sums.append(float(a[a_ix].sum())) - - x_input = model.shape_input(x_perturbed, x.shape, channel_first=True) - y_pred_perturb = float(model.predict(x_input)[:, y]) - pred_deltas.append(y_pred - y_pred_perturb) - - # Each list-element of self.evaluation_scores will be such a dictionary - # We will unpack that later in custom_postprocess(). - return {"att_sums": att_sums, "pred_deltas": pred_deltas} - def custom_preprocess( self, x_batch: np.ndarray, @@ -382,35 +320,18 @@ def custom_postprocess( ------- None """ - max_features = int( - self.n_max_percentage * np.prod(x_batch.shape[2:]) // self.features_in_step - ) + max_features = int(self.n_max_percentage * np.prod(x_batch.shape[2:]) // self.features_in_step) # Get pred_deltas and att_sums from result list. 
- sub_results_pred_deltas: List[Any] = [ - r["pred_deltas"] for r in self.evaluation_scores - ] - sub_results_att_sums: List[Any] = [ - r["att_sums"] for r in self.evaluation_scores - ] - - # Re-arrange sub-lists so that they are sorted by n. - sub_results_pred_deltas_l: Dict[int, Any] = {k: [] for k in range(max_features)} - sub_results_att_sums_l: Dict[int, Any] = {k: [] for k in range(max_features)} - - for k in range(max_features): - for pred_deltas_instance in sub_results_pred_deltas: - sub_results_pred_deltas_l[k].append(pred_deltas_instance[k]) - for att_sums_instance in sub_results_att_sums: - sub_results_att_sums_l[k].append(att_sums_instance[k]) + pred_deltas: np.ndarray = np.concatenate([r["pred_deltas"] for r in self.evaluation_scores], axis=0) + att_sums: np.ndarray = np.concatenate([r["att_sums"] for r in self.evaluation_scores], axis=0) # Compute the similarity for each n. - self.evaluation_scores = [ - self.similarity_func( - a=sub_results_att_sums_l[k], b=sub_results_pred_deltas_l[k] - ) - for k in range(max_features) - ] + self.evaluation_scores = self.similarity_func( + a=att_sums[:, :max_features].T, + b=pred_deltas[:, :max_features].T, + batched=True, + ) def evaluate_batch( self, @@ -442,7 +363,53 @@ def evaluate_batch( scores_batch: The evaluation results. """ - return [ - self.evaluate_instance(model=model, x=x, y=y, a=a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + # Prepare shapes. Expand a_batch if not the same shape + if x_batch.shape != a_batch.shape: + a_batch = np.broadcast_to(a_batch, x_batch.shape) + + # Flatten the attributions. + batch_size = a_batch.shape[0] + a_batch = a_batch.reshape(batch_size, -1) + n_features = a_batch.shape[-1] + + # Get indices of sorted attributions (descending). + a_indices = np.argsort(-a_batch, axis=1) + + # Predict on x.
+ x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) + y_pred = model.predict(x_input)[np.arange(batch_size), y_batch] + + n_perturbations = math.ceil(n_features / self.features_in_step) + pred_deltas = [] + att_sums = [] + x_batch_shape = x_batch.shape + x_perturbed = x_batch.copy() + for perturbation_step_index in range(n_perturbations): + # Perturb input by indices of attributions. + a_ix = a_indices[ + :, + perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step, + ] + x_perturbed = self.perturb_func( + arr=x_perturbed.reshape(batch_size, -1), + indices=a_ix, + ) + x_perturbed = x_perturbed.reshape(*x_batch_shape) + + # Check if the perturbation caused change + for x_element, x_perturbed_element in zip(x_batch, x_perturbed): + warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element) + + # Sum attributions. + att_sums.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=-1)) + + # Predict on perturbed input x. + x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True) + y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch] + pred_deltas.append(y_pred - y_pred_perturb) + pred_deltas = np.stack(pred_deltas, axis=1) + att_sums = np.stack(att_sums, axis=1) + + # Each list-element of self.evaluation_scores will be such a dictionary + # We will unpack that later in custom_postprocess(). 
+ return [{"att_sums": att_sums, "pred_deltas": pred_deltas}] diff --git a/quantus/metrics/faithfulness/sufficiency.py b/quantus/metrics/faithfulness/sufficiency.py index 4d2fb9ca..4de52c3b 100644 --- a/quantus/metrics/faithfulness/sufficiency.py +++ b/quantus/metrics/faithfulness/sufficiency.py @@ -249,40 +249,6 @@ def __call__( **kwargs, ) - @staticmethod - def evaluate_instance( - i: int, - a_sim_vector: np.ndarray, - y_pred_classes: np.ndarray, - ) -> float: - """ - Evaluate instance gets model and data for a single instance as input and returns the evaluation result. - - Parameters - ---------- - i: int - The index of the current instance. - a_sim_vector: any - The custom input to be evaluated on an instance-basis. - y_pred_classes: np,ndarray - The class predictions of the complete input dataset. - - Returns - ------- - float - The evaluation results. - """ - - # Metric logic. - pred_a = y_pred_classes[i] - low_dist_a = np.argwhere(a_sim_vector == 1.0).flatten() - low_dist_a = low_dist_a[low_dist_a != i] - pred_low_dist_a = y_pred_classes[low_dist_a] - - if len(low_dist_a) == 0: - return 0 - return np.sum(pred_low_dist_a == pred_a) / len(low_dist_a) - def custom_batch_preprocess( self, model: ModelInterface, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> Dict[str, np.ndarray]: @@ -295,9 +261,7 @@ def custom_batch_preprocess( a_sim_matrix[dist_matrix <= self.threshold] = 1 # Predict on input. - x_input = model.shape_input( - x_batch, x_batch[0].shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) y_pred_classes = np.argmax(model.predict(x_input), axis=1).flatten() return { @@ -334,10 +298,11 @@ def evaluate_batch( evaluation_scores: List of measured sufficiency for each entry in the batch.
""" - - return [ - self.evaluate_instance( - i=i, a_sim_vector=a_sim_vector, y_pred_classes=y_pred_classes - ) - for i, a_sim_vector in zip(i_batch, a_sim_vector_batch) - ] + batch_size = y_pred_classes.shape[0] + # Metric logic: zero the diagonal out-of-place so an instance never counts as its own neighbour. + a_sim_vector_batch = a_sim_vector_batch * (1 - np.eye(batch_size, dtype=np.int32)) + pred_classes_equality = y_pred_classes[None] == y_pred_classes[:, None] + low_dist_freqs = a_sim_vector_batch.sum(-1) + return (pred_classes_equality * a_sim_vector_batch).sum(-1) / np.where( + low_dist_freqs == 0, np.inf, low_dist_freqs + ) diff --git a/quantus/metrics/randomisation/random_logit.py b/quantus/metrics/randomisation/random_logit.py index cc42db9b..728a5a68 100644 --- a/quantus/metrics/randomisation/random_logit.py +++ b/quantus/metrics/randomisation/random_logit.py @@ -267,16 +267,11 @@ def evaluate_instance( """ # Randomly select off-class labels. np.random.seed(self.seed) - y_off = np.array( - [ - np.random.choice( - [y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y] - ) - ] - ) + y_off = np.array([np.random.choice([y_ for y_ in list(np.arange(0, self.num_classes)) if y_ != y])]) # Explain against a random class. a_perturbed = self.explain_batch(model, np.expand_dims(x, axis=0), y_off) - return self.similarity_func(a.flatten(), a_perturbed.flatten()) + similarity = float(self.similarity_func(a.flatten(), a_perturbed.flatten())) + return similarity def custom_preprocess( self, @@ -328,7 +323,4 @@ def evaluate_batch( scores_batch: Evaluation results.
""" - return [ - self.evaluate_instance(model, x, y, a) - for x, y, a in zip(x_batch, y_batch, a_batch) - ] + return [self.evaluate_instance(model, x, y, a) for x, y, a in zip(x_batch, y_batch, a_batch)] diff --git a/quantus/metrics/robustness/consistency.py b/quantus/metrics/robustness/consistency.py index ea782d2d..d28f9b18 100644 --- a/quantus/metrics/robustness/consistency.py +++ b/quantus/metrics/robustness/consistency.py @@ -277,9 +277,7 @@ def custom_batch_preprocess( self, model: ModelInterface, x_batch: np.ndarray, a_batch: np.ndarray, **kwargs ) -> Dict[str, np.ndarray]: """Compute additional arguments required for Consistency on batch-level.""" - x_input = model.shape_input( - x_batch, x_batch[0].shape, channel_first=True, batched=True - ) + x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True) a_batch_flat = a_batch.reshape(a_batch.shape[0], -1) a_labels = np.array(list(map(self.discretise_func, a_batch_flat))) diff --git a/tests/conftest.py b/tests/conftest.py index 6e87b554..bde949ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,13 +6,17 @@ import pytest import torch from keras.datasets import cifar10 -from quantus.helpers.model.models import (CifarCNNModel, ConvNet1D, - ConvNet1DTF, LeNet, LeNetTF, - TitanicSimpleTFModel, - TitanicSimpleTorchModel) +from quantus.helpers.model.models import ( + CifarCNNModel, + ConvNet1D, + ConvNet1DTF, + LeNet, + LeNetTF, + TitanicSimpleTFModel, + TitanicSimpleTorchModel, +) from sklearn.model_selection import train_test_split -from transformers import (AutoModelForSequenceClassification, AutoTokenizer, - set_seed) +from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed CIFAR_IMAGE_SIZE = 32 MNIST_IMAGE_SIZE = 28 @@ -20,7 +24,8 @@ MINI_BATCH_SIZE = 8 RANDOM_SEED = 42 -@pytest.fixture(scope='function', autouse=True) + +@pytest.fixture(scope="function", autouse=True) def reset_prngs(): set_seed(42) @@ -29,9 +34,7 @@ def reset_prngs(): def 
load_mnist_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" model = LeNet() - model.load_state_dict( - torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) - ) + model.load_state_dict(torch.load("tests/assets/mnist", map_location="cpu", weights_only=True)) return model @@ -58,7 +61,7 @@ def load_1d_1ch_conv_model(): model.eval() # TODO: add trained model weights # model.load_state_dict( - # torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) + # torch.load("tests/assets/mnist", map_location="cpu", weights_only=True) # ) return model @@ -70,7 +73,7 @@ def load_1d_3ch_conv_model(): model.eval() # TODO: add trained model weights # model.load_state_dict( - # torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) + # torch.load("tests/assets/mnist", map_location="cpu", weights_only=True) # ) return model @@ -89,9 +92,7 @@ def load_1d_3ch_conv_model_tf(): def load_mnist_images(): """Load a batch of MNIST digits: inputs and outputs to use for testing.""" x_batch = ( - np.loadtxt("tests/assets/mnist_x") - .astype(float) - .reshape((BATCH_SIZE, 1, MNIST_IMAGE_SIZE, MNIST_IMAGE_SIZE)) + np.loadtxt("tests/assets/mnist_x").astype(float).reshape((BATCH_SIZE, 1, MNIST_IMAGE_SIZE, MNIST_IMAGE_SIZE)) )[:MINI_BATCH_SIZE] y_batch = np.loadtxt("tests/assets/mnist_y").astype(int)[:MINI_BATCH_SIZE] return {"x_batch": x_batch, "y_batch": y_batch} @@ -101,11 +102,9 @@ def load_mnist_images(): def load_cifar10_images(): """Load a batch of MNIST digits: inputs and outputs to use for testing.""" (x_train, y_train), (_, _) = cifar10.load_data() - x_batch = ( - x_train[:BATCH_SIZE] - .reshape((BATCH_SIZE, 3, CIFAR_IMAGE_SIZE, CIFAR_IMAGE_SIZE)) - .astype(float) - )[:MINI_BATCH_SIZE] + x_batch = (x_train[:BATCH_SIZE].reshape((BATCH_SIZE, 3, CIFAR_IMAGE_SIZE, CIFAR_IMAGE_SIZE)).astype(float))[ + :MINI_BATCH_SIZE + ] y_batch =
y_train[:BATCH_SIZE].reshape(-1).astype(int)[:MINI_BATCH_SIZE] return {"x_batch": x_batch, "y_batch": y_batch} @@ -185,7 +184,7 @@ def flat_sequence_array(): @pytest.fixture(scope="session", autouse=True) def titanic_model_torch(): model = TitanicSimpleTorchModel() - model.load_state_dict(torch.load("tests/assets/titanic_model_torch.pickle")) + model.load_state_dict(torch.load("tests/assets/titanic_model_torch.pickle", weights_only=True)) return model @@ -244,9 +243,7 @@ def load_hf_distilbert_sequence_classifier(): TODO """ DISTILBERT_BASE = "distilbert-base-uncased" - model = AutoModelForSequenceClassification.from_pretrained( - DISTILBERT_BASE, cache_dir="/tmp/" - ) + model = AutoModelForSequenceClassification.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/") return model @@ -257,11 +254,11 @@ def dummy_hf_tokenizer(): """ DISTILBERT_BASE = "distilbert-base-uncased" REFERENCE_TEXT = "The quick brown fox jumps over the lazy dog" - tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/") + tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/", clean_up_tokenization_spaces=True) return tokenizer(REFERENCE_TEXT, return_tensors="pt") @pytest.fixture(scope="session", autouse=True) def set_env(): """Set ENV var, so test outputs are not polluted by progress bars and warnings.""" - os.environ["PYTEST"] = "1" \ No newline at end of file + os.environ["PYTEST"] = "1" diff --git a/tests/functions/test_perturb_func.py b/tests/functions/test_perturb_func.py index 0f2b87b6..4d76dea7 100644 --- a/tests/functions/test_perturb_func.py +++ b/tests/functions/test_perturb_func.py @@ -30,6 +30,11 @@ def input_zeros_2d_3ch_flattened(): return np.zeros(shape=(3, 224, 224)).flatten() +@pytest.fixture +def input_batch_zeros_2d_3ch_flattened(): + return np.zeros(shape=(2, 3, 224, 224)).reshape(2, -1) + + @pytest.fixture def input_uniform_2d_3ch_flattened(): return np.random.uniform(0, 0.1, size=(3, 224, 224)).flatten() @@ -161,6 +166,36 @@ def 
test_baseline_replacement_by_indices( ), f"Test failed.{out}" +@pytest.mark.perturb_func +@pytest.mark.parametrize( + "data,params,expected", + [ + ( + lazy_fixture("input_batch_zeros_2d_3ch_flattened"), + { + "indices": np.array([[0, 2], [2, 5]]), + "perturb_baseline": 1.0, + }, + 1, + ), + ], +) +def test_batch_baseline_replacement_by_indices( + data: np.ndarray, params: dict, expected: Union[float, dict, bool] +): + batch_size = data.shape[0] + # Output + out = batch_baseline_replacement_by_indices(arr=data, **params) + + # Indices + indices = params["indices"] + + if isinstance(expected, (int, float)): + assert np.all( + out[np.arange(batch_size)[:, None], indices] == expected + ), f"Test failed.{out}" + + @pytest.mark.perturb_func @pytest.mark.parametrize( "data,params,expected", diff --git a/tests/functions/test_pytorch_model.py b/tests/functions/test_pytorch_model.py index cbd24045..caa64f50 100644 --- a/tests/functions/test_pytorch_model.py +++ b/tests/functions/test_pytorch_model.py @@ -130,9 +130,7 @@ def test_get_softmax_arg_model( ), ], ) -def test_predict( - data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model -): +def test_predict(data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model): load_mnist_model.eval() training = params.pop("training", False) model = PyTorchModel(load_mnist_model, **params) @@ -171,9 +169,7 @@ def test_predict( ), ], ) -def test_shape_input( - data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model -): +def test_shape_input(data: np.ndarray, params: dict, expected: Union[float, dict, bool], load_mnist_model): load_mnist_model.eval() model = PyTorchModel(load_mnist_model, channel_first=params["channel_first"]) if not params["channel_first"]: @@ -266,9 +262,7 @@ def test_add_mean_shift_to_first_layer(load_mnist_model): ( lazy_fixture("load_hf_distilbert_sequence_classifier"), { - "input_ids": torch.tensor( - [[101, 1996, 4248, 2829, 4419, 
14523, 2058, 1996, 13971, 3899, 102]] - ), + "input_ids": torch.tensor([[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]), "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), }, False, @@ -291,15 +285,11 @@ def test_add_mean_shift_to_first_layer(load_mnist_model): ), ], ) -def test_huggingface_classifier_predict( - hf_model, data, softmax, model_kwargs, expected -): - model = PyTorchModel( - model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs - ) +def test_huggingface_classifier_predict(hf_model, data, softmax, model_kwargs, expected): + model = PyTorchModel(model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs) with expected: out = model.predict(x=data) - assert np.allclose(out, expected.enter_result), "Test failed." + assert np.allclose(out, expected.enter_result, atol=1e-5), "Test failed." @pytest.fixture diff --git a/tests/metrics/test_faithfulness_metrics.py b/tests/metrics/test_faithfulness_metrics.py index 775fdd58..3a72d1df 100644 --- a/tests/metrics/test_faithfulness_metrics.py +++ b/tests/metrics/test_faithfulness_metrics.py @@ -6,6 +6,7 @@ from quantus.functions.explanation_func import explain from quantus.functions.perturb_func import ( + batch_baseline_replacement_by_indices, baseline_replacement_by_indices, noisy_linear_imputation, ) @@ -39,7 +40,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "perturb_baseline": "mean", "similarity_func": correlation_spearman, @@ -61,7 +62,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "perturb_baseline": "mean", "similarity_func": correlation_spearman, @@ -84,7 +85,7 @@ { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": 
batch_baseline_replacement_by_indices, "nr_runs": 10, "similarity_func": correlation_spearman, "normalise": True, @@ -105,7 +106,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "similarity_func": correlation_spearman, "normalise": True, @@ -126,7 +127,7 @@ lazy_fixture("load_mnist_images_tf"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "perturb_baseline": "mean", "similarity_func": correlation_spearman, @@ -148,7 +149,7 @@ lazy_fixture("load_mnist_images_tf"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "nr_runs": 10, "similarity_func": correlation_spearman, "normalise": True, @@ -169,7 +170,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -190,7 +191,7 @@ { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -209,7 +210,7 @@ { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -227,7 +228,7 @@ lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "nr_runs": 10, "similarity_func": correlation_spearman, @@ -291,9 +292,7 @@ def test_faithfulness_correlation( **call_params, )[0] - assert np.all( - 
((scores >= expected["min"]) & (scores <= expected["max"])) - ), "Test failed." + assert np.all(((scores >= expected["min"]) & (scores <= expected["max"]))), "Test failed." @pytest.mark.faithfulness @@ -305,7 +304,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "normalise": True, @@ -326,7 +325,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 196, "perturb_baseline": "uniform", "normalise": True, @@ -347,7 +346,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "normalise": True, @@ -369,7 +368,7 @@ def test_faithfulness_correlation( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "abs": True, @@ -391,7 +390,7 @@ def test_faithfulness_correlation( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "uniform", "normalise": True, @@ -413,7 +412,7 @@ def test_faithfulness_correlation( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "uniform", "features_in_step": 10, "normalise": True, @@ -461,9 +460,7 @@ def test_faithfulness_estimate( **call_params, ) - assert all( - ((s >= expected["min"]) & (s <= 
expected["max"])) for s in scores - ), "Test failed." + assert all(((s >= expected["min"]) & (s <= expected["max"])) for s in scores), "Test failed." @pytest.mark.faithfulness @@ -597,9 +594,7 @@ def test_iterative_removal_of_features( **call_params, ) - assert all( - ((s >= expected["min"]) & (s <= expected["max"])) for s in scores - ), "Test failed." + assert all(((s >= expected["min"]) & (s <= expected["max"])) for s in scores), "Test failed." @pytest.mark.faithfulness @@ -611,7 +606,7 @@ def test_iterative_removal_of_features( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "black", "normalise": True, @@ -632,7 +627,7 @@ def test_iterative_removal_of_features( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "white", "normalise": True, @@ -654,7 +649,7 @@ def test_iterative_removal_of_features( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "mean", "normalise": True, @@ -675,7 +670,7 @@ def test_iterative_removal_of_features( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "features_in_step": 28, "perturb_baseline": "black", "normalise": True, @@ -697,7 +692,7 @@ def test_iterative_removal_of_features( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "black", "features_in_step": 10, "normalise": True, @@ -980,7 +975,7 @@ def test_monotonicity_correlation( "init": { "features_in_step": 10, "normalise": 
False, - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "disable_warnings": True, }, @@ -1017,7 +1012,7 @@ def test_monotonicity_correlation( "init": { "features_in_step": 10, "normalise": False, - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "perturb_baseline": "mean", "disable_warnings": True, }, @@ -1065,13 +1060,7 @@ def test_pixel_flipping( **call_params, ) - assert all( - [ - (s >= expected["min"] and s <= expected["max"]) - for s_list in scores - for s in s_list - ] - ), "Test failed." + assert all([(s >= expected["min"] and s <= expected["max"]) for s_list in scores for s in s_list]), "Test failed." @pytest.mark.faithfulness @@ -1133,7 +1122,7 @@ def test_pixel_flipping( "normalise": True, "order": "morf", "disable_warnings": True, - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, }, "call": { "explain_func": explain, @@ -1239,13 +1228,7 @@ def test_region_perturbation( **call_params, ) - assert all( - [ - (s >= expected["min"] and s <= expected["max"]) - for s_list in scores - for s in s_list - ] - ), "Test failed." + assert all([(s >= expected["min"] and s <= expected["max"]) for s_list in scores for s in s_list]), "Test failed." @pytest.mark.faithfulness @@ -1614,9 +1597,7 @@ def test_sensitivity_n( **call_params, ) - assert all( - ((s >= expected["min"]) & (s <= expected["max"])) for s in scores - ), "Test failed." + assert all(((s >= expected["min"]) & (s <= expected["max"])) for s in scores), "Test failed." 
@pytest.mark.faithfulness @@ -1628,7 +1609,7 @@ def test_sensitivity_n( lazy_fixture("load_mnist_images"), { "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "return_aggregate": False, "normalise": True, "abs": True, @@ -1651,7 +1632,7 @@ def test_sensitivity_n( { "a_batch_generate": False, "init": { - "perturb_func": baseline_replacement_by_indices, + "perturb_func": batch_baseline_replacement_by_indices, "return_aggregate": False, "normalise": True, "abs": True, diff --git a/tests/metrics/test_localisation_metrics.py b/tests/metrics/test_localisation_metrics.py index ac86b2a7..7a43aabe 100644 --- a/tests/metrics/test_localisation_metrics.py +++ b/tests/metrics/test_localisation_metrics.py @@ -324,9 +324,7 @@ def load_artificial_attribution(): def load_mnist_adaptive_lenet_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" model = LeNetAdaptivePooling(input_shape=(1, 28, 28)) - model.load_state_dict( - torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle) - ) + model.load_state_dict(torch.load("tests/assets/mnist", map_location="cpu", weights_only=True)) return model @@ -349,9 +347,7 @@ def load_mnist_mosaics(load_mnist_images): def load_cifar10_adaptive_lenet_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" model = LeNetAdaptivePooling(input_shape=(3, 32, 32)) - model.load_state_dict( - torch.load("tests/assets/cifar10", map_location="cpu", pickle_module=pickle) - ) + model.load_state_dict(torch.load("tests/assets/cifar10", map_location="cpu", weights_only=True)) return model @@ -917,9 +913,7 @@ def test_relevance_mass_accuracy( print(scores) if isinstance(expected, float): - assert ( - all(round(s, 2) == round(expected, 2) for s in scores) == True - ), "Test failed." + assert all(round(s, 2) == round(expected, 2) for s in scores) == True, "Test failed." 
elif "type" in expected: assert isinstance(scores, expected["type"]), "Test failed." else: