Skip to content

Commit

Permalink
Merge branch 'main' into xai-inverse-estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
annahedstroem authored Feb 23, 2024
2 parents 6cc1cda + 47a41b1 commit 0a640fb
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ _Quantus is currently under active development so carefully note the Quantus rel

## News and Highlights! :rocket:

- If you want to contribute/ improve/ extend Quantus, join our [Discord](https://discord.gg/HB77krUE)!
- New metrics added: [EfficientMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/efficient_mprt.py) and [SmoothMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/smooth_mprt.py) by [Hedström et al., (2023)](https://openreview.net/pdf?id=vVpefYmnsG)
- Released a new version [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases)
- Accepted to Journal of Machine Learning Research (MLOSS), read the [paper](https://jmlr.org/papers/v24/22-0142.html)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ tests = [
"coverage>=7.2.3",
"flake8<=4.0.1; python_version == '3.7'",
"flake8>=6.0.0; python_version > '3.7'",
"pytest>=7.3.1",
"pytest<=7.4.4",
"pytest-cov>=4.0.0",
"pytest-lazy-fixture>=0.6.3",
"pytest-mock==3.10.0",
Expand Down
6 changes: 6 additions & 0 deletions quantus/helpers/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,12 @@ def assert_value_smaller_than_input_size(
-------
None
"""
if len(x.shape) == 2:
if value >= np.prod(x.shape[1:]):
raise ValueError(
f"'{value_name}' must be smaller than input size."
f" [{value} >= {np.prod(x.shape[1:])}]"
)
if value >= np.prod(x.shape[2:]):
raise ValueError(
f"'{value_name}' must be smaller than input size."
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/faithfulness/infidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Infidelity(Metric[List[float]]):
"""

name = "Infidelity"
data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
data_applicability = {DataType.IMAGE}
model_applicability = {ModelType.TORCH, ModelType.TF}
score_direction = ScoreDirection.LOWER
evaluation_category = EvaluationCategory.FAITHFULNESS
Expand Down
2 changes: 1 addition & 1 deletion quantus/metrics/randomisation/efficient_mprt.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
Indicates whether normalise operation is applied on the attribution, default=True.
normalise_func: callable
Attribution normalisation function applied in case normalise=True.
If normalise_func=None, the default value is used, default=normalise_by_max.
If normalise_func=None, the default value is used, default=normalise_by_average_second_moment_estimate.
normalise_func_kwargs: dict
Keyword arguments to be passed to normalise_func on call, default={}.
return_aggregate: boolean
Expand Down
12 changes: 11 additions & 1 deletion quantus/metrics/randomisation/smooth_mprt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from scipy import stats

from quantus.functions.similarity_func import correlation_spearman
from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate
from quantus.helpers import asserts, warn, utils
from quantus.helpers.enums import (
DataType,
Expand Down Expand Up @@ -130,7 +131,7 @@ def __init__(
Indicates whether normalise operation is applied on the attribution, default=True.
normalise_func: callable
Attribution normalisation function applied in case normalise=True.
If normalise_func=None, the default value is used, default=normalise_by_max.
If normalise_func=None, the default value is used, default=normalise_by_average_second_moment_estimate.
normalise_func_kwargs: dict
Keyword arguments to be passed to normalise_func on call, default={}.
return_aggregate: boolean
Expand Down Expand Up @@ -163,7 +164,16 @@ def __init__(
# Save metric-specific attributes.
if similarity_func is None:
similarity_func = correlation_spearman
if normalise_func is None:
normalise_func = normalise_by_average_second_moment_estimate

if normalise_func_kwargs is None:
normalise_func_kwargs = {}

self.similarity_func = similarity_func
self.normalise_func = normalise_func
self.abs = abs
self.normalise_func_kwargs = normalise_func_kwargs
self.layer_order = layer_order
self.seed = seed
self.nr_samples = nr_samples
Expand Down

0 comments on commit 0a640fb

Please sign in to comment.