Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batched metrics #351

10 changes: 6 additions & 4 deletions quantus/functions/loss_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np


def mse(a: np.array, b: np.array, **kwargs) -> float:
def mse(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float:
"""
Calculate Mean Squared Error between two images (or explanations).

Expand All @@ -19,6 +19,8 @@ def mse(a: np.array, b: np.array, **kwargs) -> float:
Array to calculate MSE with.
b: np.ndarray
Array to calculate MSE with.
batched: bool
True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size and F is the number of features
kwargs: optional
Keyword arguments.
normalise_mse: boolean
Expand All @@ -31,10 +33,10 @@ def mse(a: np.array, b: np.array, **kwargs) -> float:
"""

normalise = kwargs.get("normalise_mse", False)

axis = -1 if batched else 0
if normalise:
# Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2.
return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0)
return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=axis)
# If no need to normalise, return (a-b)^2.

return np.average(((a - b) ** 2), axis=0)
return np.average(((a - b) ** 2), axis=axis)
18 changes: 9 additions & 9 deletions quantus/functions/norm_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def fro_norm(a: np.array) -> float:
Parameters
----------
a: np.ndarray
The array to calculate the Frobenius on.
The array to calculate the Frobenius on. If 2D, the array is assumed to be batched.

Returns
-------
float
The norm.
"""
assert a.ndim == 1, "Check that 'fro_norm' receives a 1D array."
return np.linalg.norm(a)
assert a.ndim == 1 or a.ndim == 2, "Check that 'fro_norm' receives a 1D or 2D array."
return np.linalg.norm(a, axis=-1)


def l2_norm(a: np.array) -> float:
Expand All @@ -34,15 +34,15 @@ def l2_norm(a: np.array) -> float:
Parameters
----------
a: np.ndarray
The array to calculate the L2 on
The array to calculate the L2 on. If 2D, the array is assumed to be batched.

Returns
-------
float
The norm.
"""
assert a.ndim == 1, "Check that 'l2_norm' receives a 1D array."
return np.linalg.norm(a)
assert a.ndim == 1 or a.ndim == 2, "Check that 'l2_norm' receives a 1D array."
return np.linalg.norm(a, axis=-1)


def linf_norm(a: np.array) -> float:
Expand All @@ -52,12 +52,12 @@ def linf_norm(a: np.array) -> float:
Parameters
----------
a: np.ndarray
The array to calculate the L-inf on.
The array to calculate the L-inf on. If 2D, the array is assumed to be batched.

Returns
-------
float
The norm.
"""
assert a.ndim == 1, "Check that 'linf_norm' receives a 1D array."
return np.linalg.norm(a, ord=np.inf)
assert a.ndim == 1 or a.ndim == 2, "Check that 'linf_norm' receives a 1D or 2D array."
return np.linalg.norm(a, ord=np.inf, axis=-1)
231 changes: 215 additions & 16 deletions quantus/functions/perturb_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def perturb_batch(
None, array
"""
if indices is not None:
assert arr.shape[0] == len(
indices
), "arr and indices need same number of batches"
assert arr.shape[0] == len(indices), "arr and indices need same number of batches"

if not inplace:
arr = arr.copy()
Expand Down Expand Up @@ -107,13 +105,108 @@ def baseline_replacement_by_indices(

arr_perturbed = copy.copy(arr)

# Get the baseline value.
baseline_value = get_baseline_value(value=perturb_baseline, arr=arr, return_shape=tuple(baseline_shape), **kwargs)

# Perturb the array.
arr_perturbed[indices] = np.expand_dims(baseline_value, axis=tuple(indexed_axes))

return arr_perturbed


def batch_baseline_replacement_by_indices(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import numpy.typing as npt

def batch_baseline_replacement_by_indices(
    arr: np.ndarray,
    indices: np.ndarray,
    perturb_baseline: npt.ArrayLike,
    **kwargs,
) -> np.ndarray:

arr: np.array,
indices: np.array,
perturb_baseline: Union[float, int, str, np.array],
**kwargs,
) -> np.array:
"""
Replace indices in an array by a given baseline.

Parameters
----------
arr: np.ndarray
Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
indices: np.ndarray
Indices of the array to perturb. Shape N x I, where N is batch size and I is the number of indices to perturb.
perturb_baseline: float, int, str, np.ndarray
The baseline values to replace arr at indices with.
kwargs: optional
Keyword arguments.

Returns
-------
arr_perturbed: np.ndarray
The array which some of its indices have been perturbed.
"""

# Assert dimensions
assert (
len(arr.shape) == 2
), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features"
assert (
len(indices.shape) == 2
), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb"

batch_size = arr.shape[0]
arr_perturbed = copy.copy(arr)

# Get the baseline value.
baseline_value = get_baseline_value(
value=perturb_baseline, arr=arr, return_shape=tuple(baseline_shape), **kwargs
value=perturb_baseline,
arr=arr,
return_shape=tuple(indices.shape),
batched=True,
**kwargs,
)

# Perturb the array.
arr_perturbed[indices] = np.expand_dims(baseline_value, axis=tuple(indexed_axes))
arr_perturbed[np.arange(batch_size)[:, None], indices] = baseline_value

return arr_perturbed


def baseline_replacement_by_mask(
    arr: np.ndarray,
    mask: np.ndarray,
    perturb_baseline: Union[float, int, str, np.ndarray],
    **kwargs,
) -> np.ndarray:
    """
    Replace the masked entries of an array by a given baseline.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape is arbitrary.
    mask: np.ndarray
        Boolean mask selecting the entries to replace. Shape must be the same as the shape of arr.
    perturb_baseline: float, int, str, np.ndarray
        The baseline specification, resolved through get_baseline_value into values of the same
        shape as arr.
    kwargs: optional
        Keyword arguments, forwarded to get_baseline_value.

    Returns
    -------
    arr_perturbed: np.ndarray
        A new array holding the baseline where mask is true and the values of arr elsewhere.
        The input arr is left unmodified.
    """

    # Assert dimensions
    assert arr.shape == mask.shape, "The shape of arr must be the same as the mask shape"

    # Get the baseline value, one entry per element of arr.
    baseline_value = get_baseline_value(
        value=perturb_baseline,
        arr=arr,
        return_shape=tuple(arr.shape),
        **kwargs,
    )

    # np.where allocates a fresh result array, so no defensive copy of arr is needed.
    return np.where(mask, baseline_value, arr)

Expand Down Expand Up @@ -154,9 +247,7 @@ def baseline_replacement_by_shift(
arr_perturbed = copy.copy(arr)

# Get the baseline value.
baseline_value = get_baseline_value(
value=input_shift, arr=arr, return_shape=tuple(baseline_shape), **kwargs
)
baseline_value = get_baseline_value(value=input_shift, arr=arr, return_shape=tuple(baseline_shape), **kwargs)

# Shift the input.
arr_shifted = copy.copy(arr_perturbed)
Expand Down Expand Up @@ -264,6 +355,56 @@ def gaussian_noise(
return arr_perturbed


def batch_gaussian_noise(
    arr: np.ndarray,
    indices: np.ndarray,
    perturb_mean: float = 0.0,
    perturb_std: float = 0.01,
    **kwargs,
) -> np.ndarray:
    """
    Add gaussian noise to a batch of inputs at the given per-sample indices.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Indices of the features to perturb for each sample. Shape N x I, where N is batch size
        and I is the number of indices to perturb.
    perturb_mean: float
        The mean for gaussian noise.
    perturb_std: float
        The standard deviation for gaussian noise.
    kwargs: optional
        Keyword arguments. Not used by this function.

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which the entries selected by indices have had noise added.
        The input arr is left unmodified.
    """
    # Assert dimensions
    assert (
        len(arr.shape) == 2
    ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features"
    assert (
        len(indices.shape) == 2
    ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb"

    # Column vector of row numbers; broadcast against indices it pairs every
    # index with the sample (row) it belongs to.
    rows = np.arange(arr.shape[0])[:, None]
    arr_perturbed = copy.copy(arr)

    # Sample the noise. It is drawn for the full array shape so that the amount of
    # randomness consumed does not depend on how many indices are perturbed.
    noise = np.random.normal(loc=perturb_mean, scale=perturb_std, size=arr.shape)

    # Perturb the array: only the selected entries are read and written, which avoids
    # materialising a full `arr + noise` temporary.
    arr_perturbed[rows, indices] = arr[rows, indices] + noise[rows, indices]

    return arr_perturbed


def uniform_noise(
arr: np.array,
indices: Tuple[slice, ...], # Alt. Union[int, Sequence[int], Tuple[np.array]],
Expand Down Expand Up @@ -303,9 +444,10 @@ def uniform_noise(
if upper_bound is None:
noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=arr.shape)
else:
assert upper_bound > lower_bound, (
"Parameter 'upper_bound' needs to be larger than 'lower_bound', "
"but {} <= {}".format(upper_bound, lower_bound)
assert (
upper_bound > lower_bound
), "Parameter 'upper_bound' needs to be larger than 'lower_bound', " "but {} <= {}".format(
upper_bound, lower_bound
)
noise = np.random.uniform(low=lower_bound, high=upper_bound, size=arr.shape)

Expand All @@ -315,6 +457,66 @@ def uniform_noise(
return arr_perturbed


def batch_uniform_noise(
    arr: np.ndarray,
    indices: np.ndarray,
    lower_bound: float = 0.02,
    upper_bound: Union[None, float] = None,
    **kwargs,
) -> np.ndarray:
    """
    Add uniform noise to a batch of inputs at the given per-sample indices.

    The noise is sampled uniformly at random from [-lower_bound, lower_bound] if upper_bound
    is None, and from [lower_bound, upper_bound] otherwise.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Indices of the features to perturb for each sample. Shape N x I, where N is batch size
        and I is the number of indices to perturb.
    lower_bound: float
        The lower bound for uniform sampling.
    upper_bound: float, optional
        The upper bound for uniform sampling. Must be larger than lower_bound when given.
    kwargs: optional
        Keyword arguments. Not used by this function.

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which the entries selected by indices have had noise added.
        The input arr is left unmodified.
    """

    # Assert dimensions
    assert (
        len(arr.shape) == 2
    ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features"
    assert (
        len(indices.shape) == 2
    ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb"

    # Column vector of row numbers; broadcast against indices it pairs every
    # index with the sample (row) it belongs to.
    rows = np.arange(arr.shape[0])[:, None]
    arr_perturbed = copy.copy(arr)

    # Sample the noise. It is drawn for the full array shape so that the amount of
    # randomness consumed does not depend on how many indices are perturbed.
    if upper_bound is None:
        # Symmetric interval around zero.
        noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=arr.shape)
    else:
        assert (
            upper_bound > lower_bound
        ), f"Parameter 'upper_bound' needs to be larger than 'lower_bound', but {upper_bound} <= {lower_bound}"
        noise = np.random.uniform(low=lower_bound, high=upper_bound, size=arr.shape)

    # Perturb the array: only the selected entries are read and written, which avoids
    # materialising a full `arr + noise` temporary.
    arr_perturbed[rows, indices] = arr[rows, indices] + noise[rows, indices]

    return arr_perturbed


def rotation(arr: np.array, perturb_angle: float = 10, **kwargs) -> np.array:
"""
Rotate array by some given angle, assumes image type data and channel first layout.
Expand All @@ -336,8 +538,7 @@ def rotation(arr: np.array, perturb_angle: float = 10, **kwargs) -> np.array:

if arr.ndim != 3:
raise ValueError(
"perturb func 'rotation' requires image-type data."
"Check that this perturb_func receives a 3D array."
"perturb func 'rotation' requires image-type data." "Check that this perturb_func receives a 3D array."
)

matrix = cv2.getRotationMatrix2D(
Expand Down Expand Up @@ -523,9 +724,7 @@ def noisy_linear_imputation(
a[out_off_coord_ids, variable_ids] = weight

# Reduce weight for invalid coordinates.
sum_neighbors[np.argwhere(valid == 0).flatten()] = (
sum_neighbors[np.argwhere(valid == 0).flatten()] - weight
)
sum_neighbors[np.argwhere(valid == 0).flatten()] = sum_neighbors[np.argwhere(valid == 0).flatten()] - weight

a[np.arange(len(indices)), np.arange(len(indices))] = -sum_neighbors

Expand Down
Loading
Loading