Skip to content

Commit

Permalink
Merge pull request #351 from davor10105/batched-metrics
Browse files Browse the repository at this point in the history
Batched metrics
  • Loading branch information
annahedstroem authored Nov 9, 2024
2 parents 5e227aa + 0442047 commit 53f97c2
Show file tree
Hide file tree
Showing 27 changed files with 1,243 additions and 1,352 deletions.
10 changes: 6 additions & 4 deletions quantus/functions/loss_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np


def mse(a: np.array, b: np.array, **kwargs) -> float:
def mse(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float:
"""
Calculate Mean Squared Error between two images (or explanations).
Expand All @@ -19,6 +19,8 @@ def mse(a: np.array, b: np.array, **kwargs) -> float:
Array to calculate MSE with.
b: np.ndarray
Array to calculate MSE with.
batched: bool
True if arrays are batched. Arrays are expected to be 2D (B x F), where B is batch size and F is the number of features
kwargs: optional
Keyword arguments.
normalise_mse: boolean
Expand All @@ -31,10 +33,10 @@ def mse(a: np.array, b: np.array, **kwargs) -> float:
"""

normalise = kwargs.get("normalise_mse", False)

axis = -1 if batched else 0
if normalise:
# Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2.
return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0)
return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=axis)
# If no need to normalise, return (a-b)^2.

return np.average(((a - b) ** 2), axis=0)
return np.average(((a - b) ** 2), axis=axis)
18 changes: 9 additions & 9 deletions quantus/functions/norm_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def fro_norm(a: np.array) -> float:
Parameters
----------
a: np.ndarray
The array to calculate the Frobenius on.
The array to calculate the Frobenius on. If 2D, the array is assumed to be batched.
Returns
-------
float
The norm.
"""
assert a.ndim == 1, "Check that 'fro_norm' receives a 1D array."
return np.linalg.norm(a)
assert a.ndim == 1 or a.ndim == 2, "Check that 'fro_norm' receives a 1D or 2D array."
return np.linalg.norm(a, axis=-1)


def l2_norm(a: np.array) -> float:
Expand All @@ -34,15 +34,15 @@ def l2_norm(a: np.array) -> float:
Parameters
----------
a: np.ndarray
The array to calculate the L2 on
The array to calculate the L2 on. If 2D, the array is assumed to be batched.
Returns
-------
float
The norm.
"""
assert a.ndim == 1, "Check that 'l2_norm' receives a 1D array."
return np.linalg.norm(a)
assert a.ndim == 1 or a.ndim == 2, "Check that 'l2_norm' receives a 1D array."
return np.linalg.norm(a, axis=-1)


def linf_norm(a: np.array) -> float:
Expand All @@ -52,12 +52,12 @@ def linf_norm(a: np.array) -> float:
Parameters
----------
a: np.ndarray
The array to calculate the L-inf on.
The array to calculate the L-inf on. If 2D, the array is assumed to be batched.
Returns
-------
float
The norm.
"""
assert a.ndim == 1, "Check that 'linf_norm' receives a 1D array."
return np.linalg.norm(a, ord=np.inf)
assert a.ndim == 1 or a.ndim == 2, "Check that 'linf_norm' receives a 1D or 2D array."
return np.linalg.norm(a, ord=np.inf, axis=-1)
231 changes: 215 additions & 16 deletions quantus/functions/perturb_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def perturb_batch(
None, array
"""
if indices is not None:
assert arr.shape[0] == len(
indices
), "arr and indices need same number of batches"
assert arr.shape[0] == len(indices), "arr and indices need same number of batches"

if not inplace:
arr = arr.copy()
Expand Down Expand Up @@ -107,13 +105,108 @@ def baseline_replacement_by_indices(

arr_perturbed = copy.copy(arr)

# Get the baseline value.
baseline_value = get_baseline_value(value=perturb_baseline, arr=arr, return_shape=tuple(baseline_shape), **kwargs)

# Perturb the array.
arr_perturbed[indices] = np.expand_dims(baseline_value, axis=tuple(indexed_axes))

return arr_perturbed


def batch_baseline_replacement_by_indices(
    arr: np.array,
    indices: np.array,
    perturb_baseline: Union[float, int, str, np.array],
    **kwargs,
) -> np.array:
    """
    Replace per-sample indices of a batched array by a given baseline.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Indices of the array to perturb. Shape N x I, where N is batch size and I is the number
        of indices to perturb per sample.
    perturb_baseline: float, int, str, np.ndarray
        The baseline values to replace arr at indices with.
    kwargs: optional
        Keyword arguments, forwarded to ``get_baseline_value``.

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which the selected entries have been replaced by the baseline.
    """

    # Assert dimensions.
    assert (
        len(arr.shape) == 2
    ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features"
    assert (
        len(indices.shape) == 2
    ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb"

    batch_size = arr.shape[0]
    # Work on a copy so the caller's array is left untouched.
    arr_perturbed = copy.copy(arr)

    # Get the baseline value; one value is requested per perturbed entry (shape N x I).
    # NOTE(review): assumes get_baseline_value honours `return_shape` and `batched` - confirm.
    baseline_value = get_baseline_value(
        value=perturb_baseline,
        arr=arr,
        return_shape=tuple(indices.shape),
        batched=True,
        **kwargs,
    )

    # Perturb the array. The (N, 1) row selector broadcasts against the (N, I)
    # indices so that row n is only written at the columns listed in indices[n].
    arr_perturbed[np.arange(batch_size)[:, None], indices] = baseline_value

    return arr_perturbed


def baseline_replacement_by_mask(
    arr: np.array,
    mask: np.array,
    perturb_baseline: Union[float, int, str, np.array],
    **kwargs,
) -> np.array:
    """
    Replace the masked entries of an array by a given baseline.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape is arbitrary.
    mask: np.ndarray
        Boolean mask of the array to perturb. Shape must be the same as the arr.
        Entries where the mask is truthy are replaced.
    perturb_baseline: float, int, str, np.ndarray
        The baseline values to replace arr at the masked positions with.
    kwargs: optional
        Keyword arguments, forwarded to ``get_baseline_value``.

    Returns
    -------
    arr_perturbed: np.ndarray
        A new array equal to arr outside the mask and to the baseline inside it.
    """

    # Assert dimensions.
    assert arr.shape == mask.shape, "The shape of arr must be the same as the mask shape"

    # Get the baseline value, requested with the same shape as arr.
    baseline_value = get_baseline_value(
        value=perturb_baseline,
        arr=arr,
        return_shape=tuple(arr.shape),
        **kwargs,
    )

    # Perturb the array. np.where allocates a fresh result, so the input is
    # never modified and no explicit copy of arr is needed beforehand.
    arr_perturbed = np.where(mask, baseline_value, arr)

    return arr_perturbed

Expand Down Expand Up @@ -154,9 +247,7 @@ def baseline_replacement_by_shift(
arr_perturbed = copy.copy(arr)

# Get the baseline value.
baseline_value = get_baseline_value(
value=input_shift, arr=arr, return_shape=tuple(baseline_shape), **kwargs
)
baseline_value = get_baseline_value(value=input_shift, arr=arr, return_shape=tuple(baseline_shape), **kwargs)

# Shift the input.
arr_shifted = copy.copy(arr_perturbed)
Expand Down Expand Up @@ -264,6 +355,56 @@ def gaussian_noise(
return arr_perturbed


def batch_gaussian_noise(
    arr: np.array,
    indices: np.array,
    perturb_mean: float = 0.0,
    perturb_std: float = 0.01,
    **kwargs,
) -> np.array:
    """
    Add gaussian noise to a batched input at the given per-sample indices.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Integer indices of the features to perturb. Shape N x I, where N is batch size and I is
        the number of indices to perturb per sample.
    perturb_mean (float):
        The mean for gaussian noise.
    perturb_std (float):
        The standard deviation for gaussian noise.
    kwargs: optional
        Keyword arguments. Accepted for interface compatibility; not used here.

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which the selected entries have had noise added.
    """
    # Assert dimensions.
    assert (
        len(arr.shape) == 2
    ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features"
    assert (
        len(indices.shape) == 2
    ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb"

    batch_size = arr.shape[0]
    # Work on a copy so the caller's array is left untouched.
    arr_perturbed = copy.copy(arr)

    # Sample the noise for the full array shape; this keeps the random stream
    # identical regardless of which indices end up being perturbed.
    noise = np.random.normal(loc=perturb_mean, scale=perturb_std, size=arr.shape)

    # Perturb the array. The (N, 1) row selector broadcasts against the (N, I)
    # indices so that row n is only touched at the columns listed in indices[n].
    # Only the selected entries are summed, avoiding a full-size temporary.
    rows = np.arange(batch_size)[:, None]
    arr_perturbed[rows, indices] = arr[rows, indices] + noise[rows, indices]

    return arr_perturbed


def uniform_noise(
arr: np.array,
indices: Tuple[slice, ...], # Alt. Union[int, Sequence[int], Tuple[np.array]],
Expand Down Expand Up @@ -303,9 +444,10 @@ def uniform_noise(
if upper_bound is None:
noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=arr.shape)
else:
assert upper_bound > lower_bound, (
"Parameter 'upper_bound' needs to be larger than 'lower_bound', "
"but {} <= {}".format(upper_bound, lower_bound)
assert (
upper_bound > lower_bound
), "Parameter 'upper_bound' needs to be larger than 'lower_bound', " "but {} <= {}".format(
upper_bound, lower_bound
)
noise = np.random.uniform(low=lower_bound, high=upper_bound, size=arr.shape)

Expand All @@ -315,6 +457,66 @@ def uniform_noise(
return arr_perturbed


def batch_uniform_noise(
    arr: np.array,
    indices: np.array,
    lower_bound: float = 0.02,
    upper_bound: Union[None, float] = None,
    **kwargs,
) -> np.array:
    """
    Add uniform noise to a batched input at the given per-sample indices.

    The noise is sampled uniformly at random from [-lower_bound, lower_bound] if upper_bound
    is None, and from [lower_bound, upper_bound] otherwise.

    Parameters
    ----------
    arr: np.ndarray
        Array to be perturbed. Shape N x F, where N is batch size and F is number of features.
    indices: np.ndarray
        Integer indices of the features to perturb. Shape N x I, where N is batch size and I is
        the number of indices to perturb per sample.
    lower_bound: float
        The lower bound for uniform sampling.
    upper_bound: float, optional
        The upper bound for uniform sampling. Must be larger than lower_bound when given.
    kwargs: optional
        Keyword arguments. Accepted for interface compatibility; not used here.

    Returns
    -------
    arr_perturbed: np.ndarray
        A copy of arr in which the selected entries have had noise added.
    """

    # Assert dimensions.
    assert (
        len(arr.shape) == 2
    ), "The array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the features"
    assert (
        len(indices.shape) == 2
    ), "The indices array must be 2-dimensional, first dimension corresponding to the batch size, and the second to the indices to perturb"

    batch_size = arr.shape[0]
    # Work on a copy so the caller's array is left untouched.
    arr_perturbed = copy.copy(arr)

    # Sample the noise for the full array shape; this keeps the random stream
    # identical regardless of which indices end up being perturbed.
    if upper_bound is None:
        # Symmetric interval around zero.
        noise = np.random.uniform(low=-lower_bound, high=lower_bound, size=arr.shape)
    else:
        assert (
            upper_bound > lower_bound
        ), "Parameter 'upper_bound' needs to be larger than 'lower_bound', but {} <= {}".format(
            upper_bound, lower_bound
        )
        noise = np.random.uniform(low=lower_bound, high=upper_bound, size=arr.shape)

    # Perturb the array. The (N, 1) row selector broadcasts against the (N, I)
    # indices so that row n is only touched at the columns listed in indices[n].
    # Only the selected entries are summed, avoiding a full-size temporary.
    rows = np.arange(batch_size)[:, None]
    arr_perturbed[rows, indices] = arr[rows, indices] + noise[rows, indices]

    return arr_perturbed


def rotation(arr: np.array, perturb_angle: float = 10, **kwargs) -> np.array:
"""
Rotate array by some given angle, assumes image type data and channel first layout.
Expand All @@ -336,8 +538,7 @@ def rotation(arr: np.array, perturb_angle: float = 10, **kwargs) -> np.array:

if arr.ndim != 3:
raise ValueError(
"perturb func 'rotation' requires image-type data."
"Check that this perturb_func receives a 3D array."
"perturb func 'rotation' requires image-type data." "Check that this perturb_func receives a 3D array."
)

matrix = cv2.getRotationMatrix2D(
Expand Down Expand Up @@ -523,9 +724,7 @@ def noisy_linear_imputation(
a[out_off_coord_ids, variable_ids] = weight

# Reduce weight for invalid coordinates.
sum_neighbors[np.argwhere(valid == 0).flatten()] = (
sum_neighbors[np.argwhere(valid == 0).flatten()] - weight
)
sum_neighbors[np.argwhere(valid == 0).flatten()] = sum_neighbors[np.argwhere(valid == 0).flatten()] - weight

a[np.arange(len(indices)), np.arange(len(indices))] = -sum_neighbors

Expand Down
Loading

0 comments on commit 53f97c2

Please sign in to comment.