XAI Inverse Estimation #337

Open
annahedstroem wants to merge 53 commits into main from xai-inverse-estimation
Changes from 45 commits
b593012
updated two metrics and base implementation for xai inverse estimation
annahedstroem Oct 9, 2023
492ea8d
finished aggregation of the inverse implementation
annahedstroem Oct 9, 2023
15efaf6
wip inverse estimation
annahedstroem Nov 15, 2023
83c6554
fixes to inverse estimation method
annahedstroem Nov 24, 2023
ba7ab51
fixes, tests and implementation
annahedstroem Dec 1, 2023
79cf6d8
automerge issues
annahedstroem Dec 7, 2023
fb3d5ef
tests passing, most todos removed
annahedstroem Dec 7, 2023
1125384
minor fixes to InverseEstimation class
annahedstroem Dec 7, 2023
6cc1cda
Update inverse_estimation.py
annahedstroem Feb 23, 2024
0a640fb
Merge branch 'main' into xai-inverse-estimation
annahedstroem Feb 23, 2024
279fb38
remove inverse_estimation edits in region_perturbation.py
annahedstroem Feb 23, 2024
a14c109
name update region_perturbation.py
annahedstroem Feb 23, 2024
c4032d9
remove metric init requirements in inverse_estimation.py
annahedstroem Feb 23, 2024
dd74095
added a normalise_func_kwargs attribute of base class, was missing
annahedstroem Feb 23, 2024
4b763b8
added second inverse method
annahedstroem Feb 23, 2024
86d7615
added second inverse method -v2
annahedstroem Feb 23, 2024
d9cb2f3
enable batching
annahedstroem Feb 23, 2024
07c9925
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 1, 2024
6aeba15
update method name
annahedstroem Mar 15, 2024
02e3d5a
added inverse_method as an arg
annahedstroem Mar 15, 2024
4da473a
batch update
annahedstroem Mar 15, 2024
3c123c7
updated inverse with batch
annahedstroem Mar 15, 2024
3ccf004
Merge branch 'xai-inverse-estimation' of https://github.com/understan…
annahedstroem Mar 15, 2024
05efd24
added wrapper
annahedstroem Mar 15, 2024
3235507
added wrapper
annahedstroem Mar 15, 2024
f415b53
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 15, 2024
4a7764f
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 24, 2024
9b1de09
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 24, 2024
3ec97cc
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 24, 2024
dfcdae4
update from True to False in assert in inverse_estimation.py
annahedstroem Mar 24, 2024
3d06f90
bugfix a_batch shape inverse_estimation.py
annahedstroem Mar 24, 2024
76d9fdc
bugfix formatting on titanic dataset
annahedstroem Mar 25, 2024
bf7ab44
merge fixes
annahedstroem Mar 25, 2024
d6a9ed9
fixed tests for inverse estimation, debugged shapes
annahedstroem Mar 25, 2024
b4b0156
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 25, 2024
138dc46
merge post-fixes, remove old files
annahedstroem Mar 25, 2024
b2c7f27
add print statement
annahedstroem Mar 25, 2024
b7215ac
added base call kwargs as class attributes to base and rewrite invers…
annahedstroem Mar 25, 2024
77eb1a8
replace assert with warning
annahedstroem Mar 25, 2024
66d75af
bugfix evaluate_batch
annahedstroem Mar 25, 2024
6f20997
small fixes wrt assert/ warns
annahedstroem Mar 25, 2024
0cddcd6
added plotting
annahedstroem Mar 26, 2024
e248ad4
fixes
annahedstroem Mar 26, 2024
e7e4493
eval func check and channel first fix
annahedstroem Mar 26, 2024
7214701
added feature for mean/ AUC calc for localisation metric for inverse …
annahedstroem Mar 27, 2024
a516928
add tests
annahedstroem Mar 27, 2024
a55ef15
plotting updates, tiny
annahedstroem Mar 27, 2024
99b2fa1
include area_between_curves flag inverse_estimation.py
annahedstroem Apr 4, 2024
3d27467
Update consistency.py
annahedstroem Apr 14, 2024
7ae5ecf
Update region_perturbation.py
annahedstroem Apr 16, 2024
4dd68ce
Update region_perturbation.py
annahedstroem Apr 16, 2024
c823d4a
Update return typing region_perturbation.py
annahedstroem Apr 16, 2024
24ec8fa
Update region_perturbation.py
annahedstroem Apr 17, 2024
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -89,7 +89,7 @@ It is possible to limit the scope of testing to specific sections of the codebase
Faithfulness metrics using python3.9 (make sure the python versions match in your environment):

```bash
-python3 -m tox run -e py39 -- -m evaluate_func -s
+python3 -m tox run -e py39 -- -m faithfulness -s
```

For a complete overview of the possible testing scopes, please refer to `pytest.ini`.
18 changes: 9 additions & 9 deletions quantus/evaluation.py
@@ -12,9 +12,7 @@
import numpy as np
import pandas as pd

-from quantus.helpers import asserts
-from quantus.helpers import utils
-from quantus.helpers import warn
+from quantus.helpers import asserts, utils, warn
from quantus.helpers.model.model_interface import ModelInterface
from quantus.functions.explanation_func import explain

@@ -162,6 +160,8 @@ def evaluate(

    if call_kwargs is None:
        call_kwargs = {"call_kwargs_empty": {}}
+    elif not isinstance(call_kwargs, Dict):
+        raise TypeError("call_kwargs type should be of Dict[str, Dict] (if not None).")
@@ -205,7 +205,7 @@ def evaluate(
a_batch = utils.expand_attribution_channel(a_batch, x_batch)

# Asserts.
-    asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch)
+    warn.warn_attributions(a_batch=a_batch, x_batch=x_batch)

elif isinstance(value, Dict):

@@ -226,7 +226,7 @@ def evaluate(
a_batch = utils.expand_attribution_channel(a_batch, x_batch)

# Asserts.
-    asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch)
+    warn.warn_attributions(a_batch=a_batch, x_batch=x_batch)

elif isinstance(value, np.ndarray):
explain_funcs[method] = explain
@@ -241,11 +241,11 @@ def evaluate(
if explain_func_kwargs is None:
explain_func_kwargs = {}

-    for (metric, metric_func) in metrics.items():
+    for metric, metric_func in metrics.items():

results[method][metric] = {}

-        for (call_kwarg_str, call_kwarg) in call_kwargs.items():
+        for call_kwarg_str, call_kwarg in call_kwargs.items():

if verbose:
print(
@@ -287,8 +287,8 @@
# Clean up the results if there is only one call_kwarg.
for method, value in xai_methods.items():
results_ordered[method] = {}
-    for (metric, metric_func) in metrics.items():
-        for (call_kwarg_str, call_kwarg) in call_kwargs.items():
+    for metric, metric_func in metrics.items():
+        for call_kwarg_str, call_kwarg in call_kwargs.items():
results_ordered[method][metric] = results[method][metric][
call_kwarg_str
]
2 changes: 1 addition & 1 deletion quantus/functions/explanation_func.py
@@ -671,7 +671,7 @@ def f_reduce_axes(a):
inputs = inputs.cpu()

inputs_numpy = inputs.detach().numpy()

for i in range(len(explanation)):
explanation[i] = torch.Tensor(
np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
69 changes: 1 addition & 68 deletions quantus/helpers/asserts.py
@@ -8,6 +8,7 @@


from typing import Callable, Tuple, Sequence, Union
+import warnings
import numpy as np


@@ -128,74 +129,6 @@ def assert_layer_order(layer_order: str) -> None:
assert layer_order in ["top_down", "bottom_up", "independent"]


-def assert_attributions(x_batch: np.array, a_batch: np.array) -> None:
-    """
-    Asserts on attributions, assumes channel first layout.
-
-    Parameters
-    ----------
-    x_batch: np.ndarray
-        The batch of input to compare the shape of the attributions with.
-    a_batch: np.ndarray
-        The batch of attributions.
-
-    Returns
-    -------
-    None
-    """
-    assert (
-        type(a_batch) == np.ndarray
-    ), "Attributions 'a_batch' should be of type np.ndarray."
-    assert np.shape(x_batch)[0] == np.shape(a_batch)[0], (
-        "The inputs 'x_batch' and attributions 'a_batch' should "
-        "include the same number of samples."
-        "{} != {}".format(np.shape(x_batch)[0], np.shape(a_batch)[0])
-    )
-    assert np.ndim(x_batch) == np.ndim(a_batch), (
-        "The inputs 'x_batch' and attributions 'a_batch' should "
-        "have the same number of dimensions."
-        "{} != {}".format(np.ndim(x_batch), np.ndim(a_batch))
-    )
-    a_shape = [s for s in np.shape(a_batch)[1:] if s != 1]
-    x_shape = [s for s in np.shape(x_batch)[1:]]
-    assert a_shape[0] == x_shape[0] or a_shape[-1] == x_shape[-1], (
-        "The dimensions of attribution and input per sample should correspond in either "
-        "the first or last dimensions, but got shapes "
-        "{} and {}".format(a_shape, x_shape)
-    )
-    assert all([a in x_shape for a in a_shape]), (
-        "All attribution dimensions should be included in the input dimensions, "
-        "but got shapes {} and {}".format(a_shape, x_shape)
-    )
-    assert all(
-        [
-            x_shape.index(a) > x_shape.index(a_shape[i])
-            for a in a_shape
-            for i in range(a_shape.index(a))
-        ]
-    ), (
-        "The dimensions of the attribution must correspond to dimensions of the input in the same order, "
-        "but got shapes {} and {}".format(a_shape, x_shape)
-    )
-    assert not np.all((a_batch == 0)), (
-        "The elements in the attribution vector are all equal to zero, "
-        "which may cause inconsistent results since many metrics rely on ordering. "
-        "Recompute the explanations."
-    )
-    assert not np.all((a_batch == 1.0)), (
-        "The elements in the attribution vector are all equal to one, "
-        "which may cause inconsistent results since many metrics rely on ordering. "
-        "Recompute the explanations."
-    )
-    assert len(set(a_batch.flatten().tolist())) > 1, (
-        "The attributions are uniformly distributed, "
-        "which may cause inconsistent results since many "
-        "metrics rely on ordering."
-        "Recompute the explanations."
-    )
-    assert not np.all((a_batch < 0.0)), "Attributions should not all be less than zero."


def assert_segmentations(x_batch: np.array, s_batch: np.array) -> None:
"""
Asserts on segmentations, assumes channel first layout.
11 changes: 10 additions & 1 deletion quantus/helpers/constants.py
@@ -27,7 +27,7 @@
"Faithfulness Correlation": FaithfulnessCorrelation,
"Faithfulness Estimate": FaithfulnessEstimate,
"Pixel-Flipping": PixelFlipping,
"Region Segmentation": RegionPerturbation,
"Region Perturbation": RegionPerturbation,
"Monotonicity-Arya": Monotonicity,
"Monotonicity-Nguyen": MonotonicityCorrelation,
"Selectivity": Selectivity,
@@ -74,6 +74,15 @@
},
}

# Quantus metrics that include a step-wise 'masking'/perturbation that is
# based on attribution order/ranking (and not magnitude).
AVAILABLE_INVERSE_ESTIMATION_METRICS = {
    "Pixel-Flipping": PixelFlipping,
    "Region Perturbation": RegionPerturbation,  # order = 'morf'
    "ROAD": ROAD,  # return_only_values = True
    "Selectivity": Selectivity,
}
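As a hedged usage sketch (not part of the diff), a ranking-based metric could be looked up from this mapping and configured as the inline comments suggest:

```python
from quantus.helpers.constants import AVAILABLE_INVERSE_ESTIMATION_METRICS

# Pick a metric that supports inverse estimation; order='morf'
# (most relevant first) follows the comment above. Remaining
# constructor arguments are left at their defaults.
metric_cls = AVAILABLE_INVERSE_ESTIMATION_METRICS["Region Perturbation"]
metric = metric_cls(order="morf")
```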

AVAILABLE_PERTURBATION_FUNCTIONS = {
"baseline_replacement_by_indices": baseline_replacement_by_indices,
32 changes: 32 additions & 0 deletions quantus/helpers/model/pytorch_model.py
@@ -362,6 +362,38 @@ def sample(
)
return model_copy

def perturb_layer_weights(self, layer_idx: int, noise: float):
"""
Perturb the weights of a specific layer in a PyTorch model.

Parameters
----------
model : torch.nn.Module
The PyTorch model.
layer_idx : int
The index of the layer to perturb.
noise : float
The standard deviation of the Gaussian noise to add to the weights.

Returns
-------
None
"""
original_parameters = self.state_dict()
model_copy = deepcopy(self.model)
model_copy.load_state_dict(original_parameters)

# Get the specific layer.
layer = list(model_copy.modules())[layer_idx]

# Generate Gaussian noise.
noise_tensor = torch.randn_like(layer.weight) * noise

# Add the noise to the layer's weights.
layer.weight.data.add_(noise_tensor)

return model_copy
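A minimal usage sketch for the new method; the wrapper constructor arguments shown are assumptions, not taken from this diff:

```python
import torch
from quantus.helpers.model.pytorch_model import PyTorchModel

net = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
wrapped = PyTorchModel(model=net, channel_first=False)

# modules() yields the Sequential container at index 0, then its children,
# so layer_idx=1 targets the first Linear layer here.
noisy_net = wrapped.perturb_layer_weights(layer_idx=1, noise=0.01)
```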

def add_mean_shift_to_first_layer(
self,
input_shift: Union[int, float],
36 changes: 32 additions & 4 deletions quantus/helpers/perturbation_utils.py
@@ -28,7 +28,23 @@ def __call__(
def make_perturb_func(
perturb_func: PerturbFunc, perturb_func_kwargs: Mapping[str, ...] | None, **kwargs
) -> PerturbFunc | functools.partial:
"""A utility function to save few lines of code during perturbation metric initialization."""
"""
A utility function to save few lines of code during perturbation metric initialization.

Parameters
----------
perturb_func: callable
Perturbation function.
perturb_func_kwargs: dict
Perturbation function kwargs.
kwargs: dict
Perturbation metric kwargs.

Returns
-------
perturb_func: callable
Perturbation function.
"""
if perturb_func_kwargs is not None:
func_kwargs = kwargs.copy()
func_kwargs.update(perturb_func_kwargs)
@@ -41,7 +57,19 @@ def make_changed_prediction_indices_func(
def make_changed_prediction_indices_func(
return_nan_when_prediction_changes: bool,
) -> Callable[[ModelInterface, np.ndarray, np.ndarray], List[int]]:
"""A utility function to improve static analysis."""
"""
A utility function to improve static analysis.

Parameters
----------
return_nan_when_prediction_changes: boolean
Indicates if metric should return NaN when model prediction changes due to perturbation.

Returns
-------
changed_prediction_indices: callable
Function that returns indices in batch, for which predicted label has changed after applying perturbation.
"""
return functools.partial(
changed_prediction_indices,
return_nan_when_prediction_changes=return_nan_when_prediction_changes,
@@ -62,15 +90,15 @@ def changed_prediction_indices(
----------
return_nan_when_prediction_changes:
Instance attribute of perturbation metrics.
-model:
+model: ModelInterface
Model to be used for prediction.
x_batch:
Batch of original inputs provided by user.
x_perturbed:
Batch of inputs after applying perturbation.

Returns
-------

+changed_idx:
    List of indices in batch, for which predicted label has changed after applying perturbation.

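A hedged sketch of how the factory above is meant to be used inside a perturbation metric; `model`, `x_batch`, and `x_perturbed` are assumed to be supplied by the surrounding metric code:

```python
# Build once at metric initialisation, then call per batch.
changed_prediction_indices_func = make_changed_prediction_indices_func(
    return_nan_when_prediction_changes=True
)

# model implements quantus' ModelInterface; x_batch and x_perturbed are
# np.ndarrays of identical shape (original vs. perturbed inputs).
changed_idx = changed_prediction_indices_func(model, x_batch, x_perturbed)
# e.g. changed_idx == [1, 4] -> a metric can overwrite those scores with np.nan.
```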
61 changes: 61 additions & 0 deletions quantus/helpers/plotting.py
@@ -67,6 +67,67 @@ def plot_pixel_flipping_experiment(
plt.show()


def plot_inverse_curves(
y_batch: np.ndarray,
scores_ori: List[Any],
scores_inv: List[Any],
single_class: Union[int, None] = None,
*args,
**kwargs,
) -> None:
"""
Plot the original and inverse evaluation curves of a pixel-flipping-style experiment, as introduced in:

References:
1) Bach, Sebastian, et al. "On pixel-wise explanations for non-linear classifier
decisions by layer-wise relevance propagation." PloS one 10.7 (2015): e0130140.

Parameters
----------
y_batch: np.ndarray
The list of true labels.
scores_ori: list
    The list of evaluation scores (original curve).
scores_inv: list
    The list of evaluation scores (inverse curve).
single_class: integer, optional
An integer to specify the label to plot.
args: optional
Arguments.
kwargs: optional
Keyword arguments.

Returns
-------
None
"""

fig = plt.figure(figsize=(8, 6))
if single_class is None:
for c in np.unique(y_batch):
indices = np.where(y_batch == c)
plt.plot(
np.linspace(0, 1, len(scores_ori[0])),
np.mean(np.array(scores_ori)[indices], axis=0),
label=f"Original curve: {str(c)} ({indices[0].size} samples)",
)
plt.plot(
np.linspace(0, 1, len(scores_inv[0])),
np.mean(np.array(scores_inv)[indices], axis=0),
label=f"Inverse curve: {str(c)} ({indices[0].size} samples)",
)
plt.xlabel("Fraction of pixels flipped")
plt.ylabel("Mean Prediction")
plt.gca().set_yticklabels(
["{:.0f}%".format(x * 100) for x in plt.gca().get_yticks()]
)
plt.gca().set_xticklabels(
["{:.0f}%".format(x * 100) for x in plt.gca().get_xticks()]
)
plt.legend()
plt.show()
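A quick, self-contained sketch of calling the new plot with toy data (shapes assumed: one score curve per sample):

```python
import numpy as np

# Four samples from two classes; each curve holds 10 scores across perturbation steps.
y_batch = np.array([0, 0, 1, 1])
scores_ori = [np.linspace(1.0, 0.1, 10) for _ in y_batch]  # original order: fast decay
scores_inv = [np.linspace(1.0, 0.8, 10) for _ in y_batch]  # inverse order: slow decay
plot_inverse_curves(y_batch, scores_ori=scores_ori, scores_inv=scores_inv)
```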


def plot_selectivity_experiment(results: Dict[str, List[Any]], *args, **kwargs) -> None:
"""
Plot the selectivity experiment as done in paper:
Expand Down