Commit 6591aab

resolving mypy tests

davor10105 committed Oct 8, 2024
1 parent b506aa5 commit 6591aab
Showing 11 changed files with 63 additions and 115 deletions.
4 changes: 2 additions & 2 deletions quantus/functions/perturb_func.py
@@ -357,7 +357,7 @@ def gaussian_noise(

 def batch_gaussian_noise(
     arr: np.array,
-    indices: Tuple[slice, ...],  # Alt. Union[int, Sequence[int], Tuple[np.array]],
+    indices: np.array,
     perturb_mean: float = 0.0,
     perturb_std: float = 0.01,
     **kwargs,
@@ -459,7 +459,7 @@ def uniform_noise(

 def batch_uniform_noise(
     arr: np.array,
-    indices: Tuple[slice, ...],  # Alt. Union[int, Sequence[int], Tuple[np.array]],
+    indices: np.array,
     lower_bound: float = 0.02,
     upper_bound: Union[None, float] = None,
     **kwargs,
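For context, the signature change means callers now pass a per-sample integer index array rather than slice tuples. A minimal usage sketch (the array shapes and the random index construction are assumptions, not from the commit):

    import numpy as np
    from quantus.functions.perturb_func import batch_gaussian_noise

    # A batch of 8 flattened inputs with 100 features each (shapes assumed).
    arr = np.random.rand(8, 100)
    # One row of perturbation indices per sample, as a plain integer array.
    indices = np.stack([np.random.choice(100, 10, replace=False) for _ in range(8)])

    perturbed = batch_gaussian_noise(arr=arr, indices=indices, perturb_mean=0.0, perturb_std=0.01)
    assert perturbed.shape == arr.shape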
20 changes: 10 additions & 10 deletions quantus/functions/similarity_func.py
@@ -7,7 +7,7 @@
-# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
+# Quantus project URL: https://github.com/understandable-machine-intelligence-lab/Quantus

-from typing import Union
+from typing import Union, List

 import numpy as np
 import scipy
@@ -105,7 +105,7 @@ def correlation_kendall_tau(a: np.array, b: np.array, batched: bool = False, **k
     return scipy.stats.kendalltau(a, b)[0]


-def distance_euclidean(a: np.array, b: np.array, **kwargs) -> float:
+def distance_euclidean(a: np.array, b: np.array, **kwargs) -> Union[float, np.array]:
     """
     Calculate Euclidean distance of two images (or explanations).
Expand All @@ -120,13 +120,13 @@ def distance_euclidean(a: np.array, b: np.array, **kwargs) -> float:
Returns
-------
float
The similarity score.
Union[float, np.array]
The similarity score or a batch of similarity scores.
"""
return ((a - b) ** 2).sum(axis=-1) ** 0.5


def distance_manhattan(a: np.array, b: np.array, **kwargs) -> float:
def distance_manhattan(a: np.array, b: np.array, **kwargs) -> Union[float, np.array]:
"""
Calculate Manhattan distance of two images (or explanations).
Expand All @@ -141,8 +141,8 @@ def distance_manhattan(a: np.array, b: np.array, **kwargs) -> float:
Returns
-------
float
The similarity score.
Union[float, np.array]
The similarity score or a batch of similarity scores.
"""
return abs(a - b).sum(-1)

@@ -272,7 +272,7 @@ def cosine(a: np.array, b: np.array, **kwargs) -> float:
     return scipy.spatial.distance.cosine(u=a, v=b)


-def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float:
+def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> Union[float, List[float]]:
     """
     Calculate Structural Similarity Index Measure of two images (or explanations).
Expand All @@ -289,8 +289,8 @@ def ssim(a: np.array, b: np.array, batched: bool = False, **kwargs) -> float:
Returns
-------
float
The similarity score.
Union[float, List[float]]
The similarity score, returns a list if batched.
"""

def inner(aa: np.array, bb: np.array) -> float:
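The widened return annotations reflect that these helpers reduce over the last axis, so they accept both single examples and batches. A quick sketch (shapes assumed):

    import numpy as np
    from quantus.functions.similarity_func import distance_euclidean

    a, b = np.random.rand(100), np.random.rand(100)
    print(distance_euclidean(a=a, b=b))  # a single float

    a_batch, b_batch = np.random.rand(8, 100), np.random.rand(8, 100)
    print(distance_euclidean(a=a_batch, b=b_batch).shape)  # (8,): one score per sample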
7 changes: 2 additions & 5 deletions quantus/helpers/perturbation_utils.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import sys
-from typing import List, TYPE_CHECKING, Callable, Mapping
+from typing import List, TYPE_CHECKING, Callable, Mapping, Optional
 import numpy as np
 import functools

Expand All @@ -18,11 +18,8 @@ class PerturbFunc(Protocol):
def __call__(
self,
arr: np.ndarray,
indices: np.ndarray,
indexed_axes: np.ndarray,
**kwargs,
) -> np.ndarray:
...
) -> np.ndarray: ...


def make_perturb_func(
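With indices and indexed_axes dropped from the Protocol, any callable that takes an array plus keyword arguments and returns an array should satisfy PerturbFunc. A sketch of how make_perturb_func is used elsewhere in this commit (the partial-binding behaviour is an assumption based on the functools import in this file):

    import numpy as np
    from quantus.helpers.perturbation_utils import make_perturb_func
    from quantus.functions.perturb_func import batch_gaussian_noise

    # Bind metric-level defaults onto the perturbation function once, up front.
    perturb = make_perturb_func(batch_gaussian_noise, {"perturb_std": 0.1})

    arr = np.random.rand(4, 100)
    indices = np.stack([np.random.choice(100, 10, replace=False) for _ in range(4)])
    x_perturbed = perturb(arr=arr, indices=indices)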
4 changes: 2 additions & 2 deletions quantus/helpers/utils.py
@@ -714,7 +714,7 @@ def _unpad_array(
     return unpadded_arr


-def get_padding_size(dim: int, patch_size: int) -> Tuple[int]:
+def get_padding_size(dim: int, patch_size: int) -> Tuple[int, int]:
     """
     Calculate the padding size (optionally) needed for a patch_size.
Expand All @@ -727,7 +727,7 @@ def get_padding_size(dim: int, patch_size: int) -> Tuple[int]:
Returns
-------
Tuple[int]
Tuple[int, int]
A tuple of values passed to the utils._pad_array method for a particular dimension.
"""
modulo = dim % patch_size
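The corrected annotation matches what pad-style helpers expect: a (before, after) amount per dimension. Only the modulo line is visible in this diff, so the following is a sketch of the likely behaviour (the even split of the remainder is an assumption; the real helper may differ):

    from typing import Tuple

    def get_padding_size_sketch(dim: int, patch_size: int) -> Tuple[int, int]:
        # Pad just enough that patch_size divides the dimension evenly.
        modulo = dim % patch_size
        if modulo == 0:
            return (0, 0)
        pad = patch_size - modulo
        # Split the padding across both sides (assumed).
        return (pad // 2, pad - pad // 2)

    print(get_padding_size_sketch(28, 8))  # (2, 2): 28 + 2 + 2 = 32 is divisible by 8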
27 changes: 7 additions & 20 deletions quantus/metrics/faithfulness/faithfulness_correlation.py
@@ -147,9 +147,7 @@ def __init__(
         self.similarity_func = similarity_func
         self.nr_runs = nr_runs
         self.subset_size = subset_size
-        self.perturb_func = make_perturb_func(
-            perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
-        )
+        self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline)

         # Asserts and warnings.
         if not self.disable_warnings:
@@ -294,9 +292,7 @@ def custom_preprocess(self, x_batch: np.ndarray, **kwargs) -> None:
         returning a custom preprocess batch (custom_preprocess_batch).
         """
         # Asserts.
-        asserts.assert_value_smaller_than_input_size(
-            x=x_batch, value=self.subset_size, value_name="subset_size"
-        )
+        asserts.assert_value_smaller_than_input_size(x=x_batch, value=self.subset_size, value_name="subset_size")

     def evaluate_batch(
         self,
@@ -338,9 +334,7 @@ def evaluate_batch(
         n_features = a_batch.shape[-1]

         # Predict on input.
-        x_input = model.shape_input(
-            x_batch, x_batch.shape, channel_first=True, batched=True
-        )
+        x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True)
         y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]

         pred_deltas = []
@@ -351,10 +345,7 @@

         for i_ix in range(self.nr_runs):
             # Randomly mask by subset size.
             a_ix = np.stack(
-                [
-                    np.random.choice(n_features, self.subset_size, replace=False)
-                    for _ in range(batch_size)
-                ],
+                [np.random.choice(n_features, self.subset_size, replace=False) for _ in range(batch_size)],
                 axis=0,
             )
             x_perturbed = self.perturb_func(
@@ -365,14 +356,10 @@

             # Check if the perturbation caused change
             for x_element, x_perturbed_element in zip(x_batch, x_perturbed):
-                warn.warn_perturbation_caused_no_change(
-                    x=x_element, x_perturbed=x_perturbed_element
-                )
+                warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element)

             # Predict on perturbed input x.
-            x_input = model.shape_input(
-                x_perturbed, x_batch.shape, channel_first=True, batched=True
-            )
+            x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True)
             y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
             pred_deltas.append(y_pred - y_pred_perturb)

@@ -381,6 +368,6 @@
         pred_deltas = np.stack(pred_deltas, axis=1)
         att_sums = np.stack(att_sums, axis=1)

-        similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True)
+        similarity: np.array = self.similarity_func(a=att_sums, b=pred_deltas, batched=True)

         return similarity.tolist()
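The new annotation holds because, with batched=True, the similarity function returns one score per sample. A standalone illustration of the same pattern, using Pearson correlation (this metric's default similarity_func); shapes are assumed:

    import numpy as np
    import scipy.stats

    att_sums = np.random.rand(8, 100)     # summed attributions per run, per sample
    pred_deltas = np.random.rand(8, 100)  # prediction deltas per run, per sample

    # One correlation per sample, i.e. per row.
    similarity = np.array([scipy.stats.pearsonr(a, b)[0] for a, b in zip(att_sums, pred_deltas)])
    print(similarity.shape)  # (8,)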
27 changes: 7 additions & 20 deletions quantus/metrics/faithfulness/faithfulness_estimate.py
@@ -132,18 +132,13 @@ def __init__(
             perturb_func = batch_baseline_replacement_by_indices
         self.similarity_func = similarity_func
         self.features_in_step = features_in_step
-        self.perturb_func = make_perturb_func(
-            perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
-        )
+        self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline)

         # Asserts and warnings.
         if not self.disable_warnings:
             warn.warn_parameterisation(
                 metric_name=self.__class__.__name__,
-                sensitive_params=(
-                    "baseline value 'perturb_baseline' and similarity function "
-                    "'similarity_func'"
-                ),
+                sensitive_params=("baseline value 'perturb_baseline' and similarity function " "'similarity_func'"),
                 citation=(
                     "Alvarez-Melis, David, and Tommi S. Jaakkola. 'Towards robust interpretability"
                     " with self-explaining neural networks.' arXiv preprint arXiv:1806.07538 (2018)"
@@ -326,9 +321,7 @@ def evaluate_batch(
         a_indices = np.argsort(-a_batch, axis=1)

         # Predict on input.
-        x_input = model.shape_input(
-            x_batch, x_batch.shape, channel_first=True, batched=True
-        )
+        x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True)
         y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]

         n_perturbations = math.ceil(n_features / self.features_in_step)
@@ -339,9 +332,7 @@
             # Perturb input by indices of attributions.
             a_ix = a_indices[
                 :,
-                perturbation_step_index
-                * self.features_in_step : (perturbation_step_index + 1)
-                * self.features_in_step,
+                perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step,
             ]
             x_perturbed = self.perturb_func(
                 arr=x_batch.reshape(batch_size, -1),
@@ -351,14 +342,10 @@

             # Check if the perturbation caused change
             for x_element, x_perturbed_element in zip(x_batch, x_perturbed):
-                warn.warn_perturbation_caused_no_change(
-                    x=x_element, x_perturbed=x_perturbed_element
-                )
+                warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element)

             # Predict on perturbed input x.
-            x_input = model.shape_input(
-                x_perturbed, x_batch.shape, channel_first=True, batched=True
-            )
+            x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True)
             y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
             pred_deltas.append(y_pred - y_pred_perturb)

@@ -367,6 +354,6 @@
         pred_deltas = np.stack(pred_deltas, axis=1)
         att_sums = np.stack(att_sums, axis=1)

-        similarity = self.similarity_func(a=att_sums, b=pred_deltas, batched=True)
+        similarity: np.array = self.similarity_func(a=att_sums, b=pred_deltas, batched=True)

         return similarity.tolist()
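The joined slice expression above selects each step's block of ranked feature indices; a tiny worked example of the same indexing (sizes made up):

    import numpy as np

    features_in_step = 3
    a_indices = np.argsort(-np.random.rand(2, 9), axis=1)  # 2 samples, 9 ranked features

    step = 1  # the second perturbation step
    a_ix = a_indices[:, step * features_in_step : (step + 1) * features_in_step]
    print(a_ix.shape)  # (2, 3): the features ranked 3rd to 5th for each sample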
10 changes: 5 additions & 5 deletions quantus/metrics/faithfulness/irof.py
@@ -335,7 +335,7 @@ def evaluate_batch(

         # Segment image.
         segments_batch = []
-        s_indices_batch = []
+        s_indices_batch_list = []
         for x, a in zip(x_batch, a_batch):
             segments = utils.get_superpixel_segments(
                 img=np.moveaxis(x, 0, -1).astype("double"),
@@ -352,14 +352,14 @@

             # Sort segments based on the mean attribution (descending order).
             s_indices = np.argsort(-att_segs)
-            s_indices_batch.append(s_indices)
+            s_indices_batch_list.append(s_indices)
         segments_batch = np.stack(segments_batch, axis=0)
-        max_segments_len = max([len(s_indices) for s_indices in s_indices_batch])
+        max_segments_len = max([len(s_indices) for s_indices in s_indices_batch_list])
         mask_preds_batch = np.array(
-            [[1.0] * len(s_indices) + [0] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch]
+            [[1.0] * len(s_indices) + [0] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch_list]
         )
         s_indices_batch = np.array(
-            [s_indices.tolist() + [-1] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch]
+            [s_indices.tolist() + [-1] * (max_segments_len - len(s_indices)) for s_indices in s_indices_batch_list]
         )

         preds = []
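The rename fixes mypy's complaint about rebinding a list variable to an ndarray; the padding idiom itself is easiest to see in isolation (segment counts made up):

    import numpy as np

    # Per-sample superpixel rankings of different lengths.
    s_indices_batch_list = [np.array([2, 0, 1]), np.array([1, 0])]

    max_segments_len = max(len(s) for s in s_indices_batch_list)
    # Pad the rankings with -1 so they stack into one rectangular array.
    s_indices_batch = np.array(
        [s.tolist() + [-1] * (max_segments_len - len(s)) for s in s_indices_batch_list]
    )
    print(s_indices_batch)  # [[ 2  0  1]
                            #  [ 1  0 -1]]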
27 changes: 8 additions & 19 deletions quantus/metrics/faithfulness/monotonicity_correlation.py
@@ -149,9 +149,7 @@ def __init__(
         self.eps = eps
         self.nr_samples = nr_samples
         self.features_in_step = features_in_step
-        self.perturb_func = make_perturb_func(
-            perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
-        )
+        self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline)

         # Asserts and warnings.
         if not self.disable_warnings:
@@ -334,9 +332,7 @@ def evaluate_batch(
         """

         # Predict on input x.
-        x_input = model.shape_input(
-            x_batch, x_batch.shape, channel_first=True, batched=True
-        )
+        x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True)
         batch_size = x_batch.shape[0]
         y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]

@@ -363,9 +359,7 @@
             # Perturb input by indices of attributions.
             a_ix = a_indices[
                 :,
-                perturbation_step_index
-                * self.features_in_step : (perturbation_step_index + 1)
-                * self.features_in_step,
+                perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step,
             ]

             y_pred_perturbs = []
@@ -378,23 +372,18 @@

                 # Check if the perturbation caused change
                 for x_element, x_perturbed_element in zip(x_batch, x_perturbed):
-                    warn.warn_perturbation_caused_no_change(
-                        x=x_element, x_perturbed=x_perturbed_element
-                    )
+                    warn.warn_perturbation_caused_no_change(x=x_element, x_perturbed=x_perturbed_element)

                 # Predict on perturbed input x.
-                x_input = model.shape_input(
-                    x_perturbed, x_batch.shape, channel_first=True, batched=True
-                )
+                x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True)
                 y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
                 y_pred_perturbs.append(y_pred_perturb)
             y_pred_perturbs = np.stack(y_pred_perturbs, axis=1)

-            vars.append(
-                np.mean((y_pred_perturbs - y_pred[:, None]) ** 2, axis=1) * inv_pred
-            )
+            vars.append(np.mean((y_pred_perturbs - y_pred[:, None]) ** 2, axis=1) * inv_pred)
             atts.append(a_batch[np.arange(batch_size)[:, None], a_ix].sum(axis=1))
         vars = np.stack(vars, axis=1)
         atts = np.stack(atts, axis=1)

-        return self.similarity_func(a=atts, b=vars, batched=True).tolist()
+        similarities: np.array = self.similarity_func(a=atts, b=vars, batched=True)
+        return similarities.tolist()