diff --git a/README.md b/README.md index ea5efa81..05d37f43 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ _Quantus is currently under active development so carefully note the Quantus rel ## News and Highlights! :rocket: +- If you want to contribute/ improve/ extend Quantus, join our [Discord](https://discord.gg/HB77krUE)! - New metrics added: [EfficientMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/efficient_mprt.py) and [SmoothMPRT](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/randomisation/smooth_mprt.py) by [Hedström et al., (2023)](https://openreview.net/pdf?id=vVpefYmnsG) - Released a new version [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases) - Accepted to Journal of Machine Learning Research (MLOSS), read the [paper](https://jmlr.org/papers/v24/22-0142.html) diff --git a/pyproject.toml b/pyproject.toml index 59bd13ab..eb94a04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ tests = [ "coverage>=7.2.3", "flake8<=4.0.1; python_version == '3.7'", "flake8>=6.0.0; python_version > '3.7'", - "pytest>=7.3.1", + "pytest<=7.4.4", "pytest-cov>=4.0.0", "pytest-lazy-fixture>=0.6.3", "pytest-mock==3.10.0", diff --git a/quantus/helpers/asserts.py b/quantus/helpers/asserts.py index fe059811..66ee2930 100644 --- a/quantus/helpers/asserts.py +++ b/quantus/helpers/asserts.py @@ -286,6 +286,12 @@ def assert_value_smaller_than_input_size( ------- None """ + if len(x.shape) == 2: + if value >= np.prod(x.shape[1:]): + raise ValueError( + f"'{value_name}' must be smaller than input size." + f" [{value} >= {np.prod(x.shape[1:])}]" + ) if value >= np.prod(x.shape[2:]): raise ValueError( f"'{value_name}' must be smaller than input size." diff --git a/quantus/metrics/faithfulness/infidelity.py b/quantus/metrics/faithfulness/infidelity.py index ef621a60..6c12f61f 100644 --- a/quantus/metrics/faithfulness/infidelity.py +++ b/quantus/metrics/faithfulness/infidelity.py @@ -60,7 +60,7 @@ class Infidelity(Metric[List[float]]): """ name = "Infidelity" - data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR} + data_applicability = {DataType.IMAGE} model_applicability = {ModelType.TORCH, ModelType.TF} score_direction = ScoreDirection.LOWER evaluation_category = EvaluationCategory.FAITHFULNESS diff --git a/quantus/metrics/randomisation/efficient_mprt.py b/quantus/metrics/randomisation/efficient_mprt.py index 7c9ca368..4e4ecafe 100644 --- a/quantus/metrics/randomisation/efficient_mprt.py +++ b/quantus/metrics/randomisation/efficient_mprt.py @@ -127,7 +127,7 @@ def __init__( Indicates whether normalise operation is applied on the attribution, default=True. normalise_func: callable Attribution normalisation function applied in case normalise=True. - If normalise_func=None, the default value is used, default=normalise_by_max. + If normalise_func=None, the default value is used, default=normalise_by_average_second_moment_estimate. normalise_func_kwargs: dict Keyword arguments to be passed to normalise_func on call, default={}. return_aggregate: boolean diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index ef32f08c..66a247be 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -29,6 +29,7 @@ from scipy import stats from quantus.functions.similarity_func import correlation_spearman +from quantus.functions.normalise_func import normalise_by_average_second_moment_estimate from quantus.helpers import asserts, warn, utils from quantus.helpers.enums import ( DataType, @@ -130,7 +131,7 @@ def __init__( Indicates whether normalise operation is applied on the attribution, default=True. normalise_func: callable Attribution normalisation function applied in case normalise=True. - If normalise_func=None, the default value is used, default=normalise_by_max. + If normalise_func=None, the default value is used, default=normalise_by_average_second_moment_estimate. normalise_func_kwargs: dict Keyword arguments to be passed to normalise_func on call, default={}. return_aggregate: boolean @@ -163,7 +164,16 @@ def __init__( # Save metric-specific attributes. if similarity_func is None: similarity_func = correlation_spearman + if normalise_func is None: + normalise_func = normalise_by_average_second_moment_estimate + + if normalise_func_kwargs is None: + normalise_func_kwargs = {} + self.similarity_func = similarity_func + self.normalise_func = normalise_func + self.abs = abs + self.normalise_func_kwargs = normalise_func_kwargs self.layer_order = layer_order self.seed = seed self.nr_samples = nr_samples