automerge issues
annahedstroem committed Dec 7, 2023
2 parents ba7ab51 + 6511c84 commit 79cf6d8
Showing 6 changed files with 158 additions and 90 deletions.
README.md: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ _Quantus is currently under active development so carefully note the Quantus rel

 ## News and Highlights! :rocket:
 
-- New metrics added: [EfficientMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/efficient_mprt.py) and [SmoothMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/smooth_mprt.py) by [Hedström et al., (2023)](https://openreview.net/forum?id=vVpefYmnsG)
+- New metrics added: [EfficientMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/efficient_mprt.py) and [SmoothMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/smooth_mprt.py) by [Hedström et al., (2023)](https://openreview.net/pdf?id=vVpefYmnsG)
 - Released a new version [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases)
 - Accepted to Journal of Machine Learning Research (MLOSS), read the [paper](https://jmlr.org/papers/v24/22-0142.html)
 - Offers more than **30+ metrics in 6 categories** for XAI evaluation
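For context, a minimal sketch of invoking one of the new metrics through the standard Quantus metric interface (model, x_batch, and y_batch are placeholders for a trained classifier and a data batch; the explanation method is illustrative):

# Illustrative invocation of the new EfficientMPRT metric; `model`,
# `x_batch`, and `y_batch` are placeholders supplied by the user.
import quantus

metric = quantus.EfficientMPRT()
scores = metric(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    device="cpu",
    explain_func=quantus.explain,
    explain_func_kwargs={"method": "Saliency"},
)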
quantus/__init__.py: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 # Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
 
 # Set the correct version.
-__version__ = "0.5.2"
+__version__ = "0.5.3"
 
 # Expose quantus.evaluate to the user.
 from quantus.evaluation import evaluate
quantus/metrics/randomisation/efficient_mprt.py: 34 additions & 34 deletions

@@ -400,42 +400,42 @@ def __call__(
             if self.skip_layers and (l_ix + 1) < n_layers:
                 continue
 
             # Generate explanations on perturbed model in batches.
             a_perturbed_generator = self.generate_explanations(
                 random_layer_model, x_full_dataset, y_full_dataset, batch_size
             )

             # Compute the complexity of explanations of the perturbed model.
             self.explanation_scores_by_layer[layer_name] = []
             for a_batch, a_batch_perturbed in zip(
                 self.generate_a_batches(a_full_dataset), a_perturbed_generator
             ):
                 for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
                     score = self.evaluate_instance(
                         model=random_layer_model,
                         x=None,
                         y=None,
                         s=None,
                         a=a_instance_perturbed,
                     )
                     self.explanation_scores_by_layer[layer_name].append(score)
                     pbar.update(1)
 
             # Wrap the model.
             random_layer_model_wrapped = utils.get_wrapped_model(
                 model=random_layer_model,
                 channel_first=channel_first,
                 softmax=softmax,
                 device=device,
                 model_predict_kwargs=model_predict_kwargs,
             )

             # Predict and save complexity scores of the perturbed model outputs.
             self.model_scores_by_layer[layer_name] = []
             y_preds = random_layer_model_wrapped.predict(x_full_dataset)
             for y_ix, y_pred in enumerate(y_preds):
                 score = entropy(a=y_pred, x=y_pred)
                 self.model_scores_by_layer[layer_name].append(score)
 
         # Save evaluation scores as the relative rise in complexity.
         explanation_scores = list(self.explanation_scores_by_layer.values())
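For orientation, a rough sketch of the entropy bookkeeping above (the helper below is an assumed stand-in for the `entropy(a=..., x=...)` complexity call on `y_pred`; toy values, not library code):

import numpy as np

def entropy_sketch(a):
    # Shannon entropy of a normalised non-negative vector; an assumed
    # stand-in for the `entropy` complexity function used above.
    p = np.abs(np.asarray(a, dtype=float))
    p = p / p.sum()
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

# A trained model produces peaked softmax outputs; a randomised model's
# outputs drift toward uniform, i.e. higher entropy. The metric tracks
# the same drift for explanations, layer by layer.
y_pred_original = np.array([0.05, 0.90, 0.05])
y_pred_randomised = np.array([0.40, 0.30, 0.30])
print(entropy_sketch(y_pred_original))    # ≈ 0.39
print(entropy_sketch(y_pred_randomised))  # ≈ 1.09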
quantus/metrics/randomisation/mprt.py: 36 additions & 24 deletions
@@ -78,6 +78,7 @@ def __init__(
         similarity_func: Optional[Callable] = None,
         layer_order: str = "top_down",
         seed: int = 42,
+        return_sample_correlation: Optional[bool] = None,
         return_average_correlation: bool = False,
         return_last_correlation: bool = False,
         skip_layers: bool = False,
@@ -148,6 +149,15 @@
             **kwargs,
         )
 
+        if return_sample_correlation is not None:
+            print(
+                "'return_sample_correlation' parameter is deprecated and will be removed in future versions. "
+                "Please use 'return_average_correlation' instead. "
+                f"Setting 'return_average_correlation' to {return_sample_correlation}."
+            )
+            # Carry the deprecated value over to 'return_average_correlation'.
+            return_average_correlation = return_sample_correlation
 
         # Save metric-specific attributes.
         if similarity_func is None:
             similarity_func = correlation_spearman
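The shim above reports the deprecation with a plain print. A sketch of the same logic routed through Python's warnings machinery, which callers can filter or escalate (a hypothetical helper, not part of the commit):

import warnings
from typing import Optional

def resolve_deprecated_flag(
    return_sample_correlation: Optional[bool],
    return_average_correlation: bool,
) -> bool:
    # Hypothetical helper mirroring the shim above: prefer the new flag,
    # but honour the deprecated one if the caller still passes it.
    if return_sample_correlation is not None:
        warnings.warn(
            "'return_sample_correlation' is deprecated; use "
            "'return_average_correlation' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return return_sample_correlation
    return return_average_correlation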
@@ -323,10 +333,6 @@ def __call__(
         ):
             pbar.desc = layer_name
 
-            # Skip layers if computing delta.
-            if self.skip_layers and (l_ix + 1) < n_layers:
-                continue
-
             if l_ix == 0:
 
                 # Generate explanations on original model in batches.
Expand Down Expand Up @@ -354,28 +360,34 @@ def __call__(
self.evaluation_scores["original"].append(score)
pbar.update(1)

self.evaluation_scores[layer_name] = []
# Skip layers if computing delta.
if self.skip_layers and (l_ix + 1) < n_layers:
continue

self.evaluation_scores[layer_name] = []

# Generate explanations on perturbed model in batches.
a_perturbed_generator = self.generate_explanations(
random_layer_model, x_full_dataset, y_full_dataset, batch_size
)
# Generate explanations on perturbed model in batches.
a_perturbed_generator = self.generate_explanations(
random_layer_model, x_full_dataset, y_full_dataset, batch_size
)

# Compute the similarity of explanations of the perturbed model.
for a_batch, a_batch_perturbed in zip(
self.generate_a_batches(a_full_dataset), a_perturbed_generator
):
for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
score = self.evaluate_instance(
model=random_layer_model,
x=None,
y=None,
s=None,
a=a_instance,
a_perturbed=a_instance_perturbed,
)
self.evaluation_scores[layer_name].append(score)
pbar.update(1)
# Compute the similarity of explanations of the perturbed model.
for a_batch, a_batch_perturbed in zip(
self.generate_a_batches(a_full_dataset), a_perturbed_generator
):
for a_instance, a_instance_perturbed in zip(
a_batch, a_batch_perturbed
):
score = self.evaluate_instance(
model=random_layer_model,
x=None,
y=None,
s=None,
a=a_instance,
a_perturbed=a_instance_perturbed,
)
self.evaluation_scores[layer_name].append(score)
pbar.update(1)

if self.return_average_correlation:
self.evaluation_scores = self.recompute_average_correlation_per_sample()
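The effect of the reordering above, reduced to its control flow (a schematic with hypothetical helper names, not the library's code): the original-model scores must be produced on the first iteration before the skip fires, otherwise skip_layers=True leaves the baseline empty.

# Schematic of the corrected loop order.
for l_ix, layer_name in enumerate(layer_names):
    if l_ix == 0:
        # Baseline must be computed even when skipping layers.
        compute_original_scores()

    # The old placement of this check, at the top of the loop body, made
    # skip_layers=True bypass the baseline computation above.
    if skip_layers and (l_ix + 1) < len(layer_names):
        continue  # only the fully randomised model is scored

    compute_perturbed_scores(layer_name)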
quantus/metrics/randomisation/smooth_mprt.py: 34 additions & 30 deletions
@@ -343,18 +343,14 @@ def __call__(
         ):
             pbar.desc = layer_name
 
-            # Skip layers if computing delta.
-            if self.skip_layers and (l_ix + 1) < n_layers:
-                continue
-
             if l_ix == 0:
 
                 # Generate explanations on original model in batches.
                 a_original_generator = self.generate_explanations(
                     model.get_model(),
                     x_full_dataset,
                     y_full_dataset,
-                    **kwargs,
+                    **self.explain_func_kwargs,
                 )
 
                 # Compute the similarity of explanations of the original model.
@@ -377,31 +373,37 @@
                     self.evaluation_scores["original"].append(score)
                     pbar.update(1)
 
-            self.evaluation_scores[layer_name] = []
-
-            # Generate explanations on perturbed model in batches.
-            a_perturbed_generator = self.generate_explanations(
-                random_layer_model,
-                x_full_dataset,
-                y_full_dataset,
-                **kwargs,
-            )
+            # Skip layers if computing delta.
+            if self.skip_layers and (l_ix + 1) < n_layers:
+                continue
+
+            self.evaluation_scores[layer_name] = []
+
+            # Generate explanations on perturbed model in batches.
+            a_perturbed_generator = self.generate_explanations(
+                random_layer_model,
+                x_full_dataset,
+                y_full_dataset,
+                **self.explain_func_kwargs,
+            )
 
             # Compute the similarity of explanations of the perturbed model.
             for a_batch, a_batch_perturbed in zip(
                 self.generate_a_batches(a_full_dataset), a_perturbed_generator
             ):
-                for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
+                for a_instance, a_instance_perturbed in zip(
+                    a_batch, a_batch_perturbed
+                ):
                     score = self.evaluate_instance(
                         model=random_layer_model,
                         x=None,
                         y=None,
                         s=None,
                         a=a_instance,
                         a_perturbed=a_instance_perturbed,
                     )
                     self.evaluation_scores[layer_name].append(score)
                     pbar.update(1)
 
         if self.return_average_correlation:
             self.evaluation_scores = self.recompute_average_correlation_per_sample()
@@ -544,7 +546,9 @@ def custom_preprocess(
             return None
 
         a_batch_chunks = []
-        for a_chunk in self.generate_explanations(model, x_batch, y_batch):
+        for a_chunk in self.generate_explanations(
+            model, x_batch, y_batch, **{**kwargs, **self.explain_func_kwargs}
+        ):
             a_batch_chunks.extend(a_chunk)
         return dict(a_batch=np.asarray(a_batch_chunks))

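The merged-kwargs expression above relies on later entries winning when keys collide in a dict literal, so explicit explain_func_kwargs take precedence over anything arriving via **kwargs. A short illustration:

kwargs = {"method": "Saliency", "softmax": False}
explain_func_kwargs = {"method": "IntegratedGradients"}

merged = {**kwargs, **explain_func_kwargs}
# Later entries win: {'method': 'IntegratedGradients', 'softmax': False}
print(merged)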
tests/metrics/test_randomisation_metrics.py: 52 additions & 0 deletions

@@ -3,6 +3,7 @@
 import pytest
 from pytest_lazyfixture import lazy_fixture
 import numpy as np
+from zennit import attribution as zattr
 
 from quantus.functions.explanation_func import explain
 from quantus.functions import complexity_func, n_bins_func
@@ -745,6 +746,57 @@ def test_model_parameter_randomisation(
             },
             {"exception": ValueError},
         ),
+        (
+            lazy_fixture("load_mnist_model"),
+            lazy_fixture("load_mnist_images"),
+            {
+                "init": {
+                    "layer_order": "independent",
+                    "similarity_func": correlation_spearman,
+                    "normalise": True,
+                    "abs": True,
+                    "disable_warnings": True,
+                    "return_average_correlation": False,
+                    "return_last_correlation": True,
+                    "skip_layers": True,
+                },
+                "call": {
+                    "explain_func": explain,
+                    "explain_func_kwargs": {
+                        "shape": (8, 1, 28, 28),
+                        "canonizer": None,
+                        "composite": None,
+                        "attributor": zattr.Gradient,
+                        "xai_lib": "zennit",
+                    },
+                },
+            },
+            {"min": -1.0, "max": 1.0},
+        ),
+        (
+            lazy_fixture("load_mnist_model"),
+            lazy_fixture("load_mnist_images"),
+            {
+                "init": {
+                    "layer_order": "independent",
+                    "similarity_func": correlation_spearman,
+                    "normalise": True,
+                    "abs": True,
+                    "disable_warnings": True,
+                    "return_average_correlation": False,
+                    "return_last_correlation": True,
+                    "skip_layers": True,
+                },
+                "call": {
+                    "explain_func": explain,
+                    "explain_func_kwargs": {
+                        "attributor": zattr.IntegratedGradients,
+                        "xai_lib": "zennit",
+                    },
+                },
+            },
+            {"min": -1.0, "max": 1.0},
+        ),
     ],
 )
 def test_smooth_model_parameter_randomisation(
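The new cases exercise Quantus's zennit backend with two attributors. For orientation, a minimal sketch of zennit's attributor API roughly as these tests rely on it (a toy classifier; signatures per the zennit docs as best understood, treat as an assumption):

import torch
from zennit.attribution import Gradient

# Toy MNIST-shaped classifier; any torch.nn.Module works here.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
x = torch.randn(8, 1, 28, 28)
one_hot = torch.eye(10)[model(x).argmax(dim=-1)]  # attribution target per sample

with Gradient(model) as attributor:
    # Returns the model output and the gradient-based attribution of the
    # requested output direction with respect to the input.
    output, attribution = attributor(x, one_hot)

print(attribution.shape)  # torch.Size([8, 1, 28, 28])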
