MPRT metrics: EfficientMPRT and SmoothMPRT #308

Merged (12 commits) on Nov 24, 2023
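
The SmoothMPRT and EfficientMPRT metric classes themselves are not part of the excerpt below; the diff covers supporting changes (constants, complexity and binning helpers, model wrappers). As orientation, a hedged sketch of how the renamed and new metrics would typically be invoked, assuming default constructors and the usual Quantus metric call signature; the toy model and data are placeholders, not from the PR:

```python
import numpy as np
import torch.nn as nn
import quantus

# Placeholder setup: a tiny classifier and a random batch.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
x_batch = np.random.rand(8, 3, 32, 32).astype(np.float32)
y_batch = np.random.randint(0, 10, size=8)

# ModelParameterRandomisation is renamed to MPRT; two variants are added.
for metric_cls in (quantus.MPRT, quantus.SmoothMPRT, quantus.EfficientMPRT):
    metric = metric_cls()  # assumes the constructors work with defaults
    scores = metric(
        model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        a_batch=None,
        explain_func=quantus.explain,
        explain_func_kwargs={"method": "Saliency"},
    )
    print(metric_cls.__name__, type(scores))
```
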
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -89,7 +89,7 @@ It is possible to limit the scope of testing to specific sections of the codebase
Faithfulness metrics using python3.9 (make sure the python versions match in your environment):

```bash
python3 -m tox run -e py39 -- -m faithfulness -s
python3 -m tox run -e py39 -- -m smprt -s
```

For a complete overview of the possible testing scopes, please refer to `pytest.ini`.
5 changes: 1 addition & 4 deletions README.md
@@ -23,14 +23,11 @@ _Quantus is currently under active development so carefully note the Quantus rel

## News and Highlights! :rocket:

- Released a new version [v0.4.3](https://github.com/understandable-machine-intelligence-lab/Quantus/releases)
- Released a new version [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases)
- Accepted to Journal of Machine Learning Research (MLOSS), read the [paper](https://jmlr.org/papers/v24/22-0142.html)
- Offers more than **30+ metrics in 6 categories** for XAI evaluation
- Supports different data types (image, time-series, tabular, NLP next up!) and models (PyTorch, TensorFlow)
- Extended built-in support for explanation methods ([captum](https://captum.ai/), [tf-explain](https://tf-explain.readthedocs.io/en/latest/) and [zennit](https://github.com/chr5tphr/zennit))
- New optimisations to help speed up computation, see API reference [here](https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.base_batched.html)

See [here](https://github.com/understandable-machine-intelligence-lab/Quantus/releases) for the latest release(s).

## Citation

2 changes: 1 addition & 1 deletion quantus/__init__.py
@@ -5,7 +5,7 @@
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

# Set the correct version.
__version__ = "0.4.5"
__version__ = "0.5.0"

# Expose quantus.evaluate to the user.
from quantus.evaluation import evaluate
5 changes: 4 additions & 1 deletion quantus/evaluation.py
@@ -81,7 +81,7 @@ def evaluate(
return None

if call_kwargs is None:
call_kwargs = {'call_kwargs_empty': {}}
call_kwargs = {"call_kwargs_empty": {}}
elif not isinstance(call_kwargs, Dict):
raise TypeError("xai_methods type is not Dict[str, Dict].")

@@ -99,6 +99,9 @@

explain_funcs[method] = value
explain_func = value
assert (
explain_func_kwargs is not None
), "Pass explain_func_kwargs as a dictionary."

# Asserts.
asserts.assert_explain_func(explain_func=explain_func)
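
A hedged illustration (not from the diff) of what the added assertion means for callers of `quantus.evaluate`: when an entry in `xai_methods` is a callable, `explain_func_kwargs` must now be an explicit dictionary, even an empty one. Parameter names beyond those visible in this hunk (`model`, `x_batch`, `y_batch`) are assumed from the library's public API, and the toy model and data are placeholders:

```python
import numpy as np
import torch.nn as nn
import quantus

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # placeholder model
x_batch = np.random.rand(8, 3, 32, 32).astype(np.float32)
y_batch = np.random.randint(0, 10, size=8)

results = quantus.evaluate(
    metrics={"EfficientMPRT": quantus.EfficientMPRT()},
    xai_methods={"Saliency": quantus.explain},   # callable value -> assertion applies
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    explain_func_kwargs={"method": "Saliency"},  # must not be None after this change
    call_kwargs={"set1": {}},                    # optional; None is replaced by one empty kwargs set
)
```
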
102 changes: 102 additions & 0 deletions quantus/functions/complexity_func.py
@@ -0,0 +1,102 @@
"""This module holds a collection of functions to compute the complexity (of explanations)."""

# This file is part of Quantus.
# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import scipy
import numpy as np


def entropy(a: np.array, x: np.array, **kwargs) -> float:
"""
Calculate entropy.

Parameters
----------
a: np.ndarray
Array to calculate entropy on. One sample at a time.
x: np.ndarray
Array to compute shape.
kwargs: optional
Keyword arguments.

Returns
-------
float:
A floating point, ranging [0, inf].
"""
assert (a >= 0).all(), "Entropy computation requires non-negative values."

if len(x.shape) == 1:
newshape = np.prod(x.shape)
else:
newshape = np.prod(x.shape[1:])

a_reshaped = np.reshape(a, int(newshape))
a_normalised = a_reshaped.astype(np.float64) / np.sum(np.abs(a_reshaped))
return scipy.stats.entropy(pk=a_normalised)


def gini_coeffiient(a: np.array, x: np.array, **kwargs) -> float:
"""
Calculate Gini coefficient.

Parameters
----------
a: np.ndarray
Array to calculate gini_coeffiient on. One sample at a time.
x: np.ndarray
Array to compute shape.
kwargs: optional
Keyword arguments.

Returns
-------
float:
A floating point, ranging [0, 1].

"""

if len(x.shape) == 1:
newshape = np.prod(x.shape)
else:
newshape = np.prod(x.shape[1:])

a = np.array(np.reshape(a, newshape), dtype=np.float64)
a += 0.0000001
a = np.sort(a)
score = (np.sum((2 * np.arange(1, a.shape[0] + 1) - a.shape[0] - 1) * a)) / (
a.shape[0] * np.sum(a)
)
return score


def discrete_entropy(a: np.array, x: np.array, **kwargs) -> float:
"""
Calculate discrete entropy of explanations with n_bins equidistantly spaced bins.

Parameters
----------
a: np.ndarray
Array to calculate entropy on. One sample at a time.
x: np.ndarray
Array to compute shape.
kwargs: optional
Keyword arguments.

n_bins: int
Number of bins. Default is 100.

Returns
-------
float:
Discrete Entropy.
"""

n_bins = kwargs.get("n_bins", 100)

histogram, bins = np.histogram(a, bins=n_bins)

return scipy.stats.entropy(pk=histogram)
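
A small sanity-check sketch (not part of the diff) for the new complexity helpers: an evenly spread attribution should score higher entropy and a lower Gini coefficient than a sparse one. The input `x` is used only for its shape, and the function-name spelling follows the module above:

```python
import numpy as np
from quantus.functions.complexity_func import entropy, gini_coeffiient, discrete_entropy

x = np.random.rand(1, 3, 8, 8)          # dummy input; only its shape matters
a_flat = np.ones(3 * 8 * 8)             # attribution spread evenly over all features
a_sparse = np.zeros(3 * 8 * 8)
a_sparse[0] = 1.0                       # all attribution mass on a single feature

print(entropy(a=a_flat, x=x) > entropy(a=a_sparse, x=x))                  # True
print(gini_coeffiient(a=a_flat, x=x) < gini_coeffiient(a=a_sparse, x=x))  # True
print(discrete_entropy(a=a_flat, x=x, n_bins=50))                         # 0.0 for a constant map
```
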
8 changes: 5 additions & 3 deletions quantus/functions/explanation_func.py
@@ -277,7 +277,7 @@ def generate_tf_explanation(

if method in constants.DEPRECATED_XAI_METHODS_TF:
warnings.warn(
f"Explanaiton method string {method} is deprecated. Use "
f"Explanation method string {method} is deprecated. Use "
f"{constants.DEPRECATED_XAI_METHODS_TF[method]} instead.\n",
category=UserWarning,
)
@@ -416,7 +416,8 @@ def generate_tf_explanation(

else:
raise KeyError(
f"Specify a XAI method that already has been implemented {constants.AVAILABLE_XAI_METHODS_TF}."
f"To use the 'quantus.explain' method with tf-explain as a supporting library, "
f"specify a XAI method that is supported {constants.AVAILABLE_XAI_METHODS_TF}."
)

assert 0 not in reduce_axes, (
@@ -430,7 +431,8 @@

reduce_axes = {"axis": tuple(reduce_axes), "keepdims": keepdims}

# Prevent attribution summation for 2D-data. Recreate np.sum behavior when passing reduce_axes=(), i.e. no change.
# Prevent attribution summation for 2D-data.
# Recreate np.sum behavior when passing reduce_axes=(), i.e. no change.
if (len(tuple(reduce_axes)) == 0) | (explanation.ndim < 3):
return explanation

69 changes: 69 additions & 0 deletions quantus/functions/n_bins_func.py
@@ -0,0 +1,69 @@
"""This module holds a collection of algorithms to calculate a number of bins to use for entropy calculation."""

# This file is part of Quantus.
# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import scipy
import numpy as np


def freedman_diaconis_rule(a_batch: np.array) -> int:
"""Freedman–Diaconis' rule."""

iqr = np.percentile(a_batch, 75) - np.percentile(a_batch, 25)
n = a_batch[0].ndim
bin_width = 2 * iqr / np.power(n, 1 / 3)

# Set a minimum value for bin_width to avoid division by very small numbers.
min_bin_width = 1e-6
bin_width = max(bin_width, min_bin_width)

# Calculate number of bins based on bin width.
n_bins = int((np.max(a_batch) - np.min(a_batch)) / bin_width)

return n_bins


def scotts_rule(a_batch: np.array) -> int:
"""Scott's rule."""

std = np.std(a_batch)
n = a_batch[0].ndim

# Calculate bin width using Scott's rule.
bin_width = 3.5 * std / np.power(n, 1 / 3)

# Calculate number of bins based on bin width.
n_bins = int((np.max(a_batch) - np.min(a_batch)) / bin_width)

return n_bins


def square_root_choice(a_batch: np.array) -> int:
"""Square-root choice rule."""

n = a_batch[0].ndim
n_bins = int(np.sqrt(n))

return n_bins


def sturges_formula(a_batch: np.array) -> int:
"""Sturges' formula."""

n = a_batch[0].ndim
n_bins = int(np.log2(n) + 1)

return n_bins


def rice_rule(a_batch: np.array) -> int:
"""Rice Rule."""

n = a_batch[0].ndim
n_bins = int(2 * np.power(n, 1 / 3))

return n_bins
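
A hedged comparison sketch (not part of the diff): each rule takes a batch of attributions and returns an integer bin count, so they can be swapped freely wherever an `n_bins` value is needed. The batch below is a random placeholder:

```python
import numpy as np
from quantus.functions import n_bins_func

a_batch = np.random.rand(8, 3, 32, 32)  # placeholder batch of attribution maps

rules = {
    "Freedman Diaconis": n_bins_func.freedman_diaconis_rule,
    "Scotts": n_bins_func.scotts_rule,
    "Square Root": n_bins_func.square_root_choice,
    "Sturges Formula": n_bins_func.sturges_formula,
    "Rice": n_bins_func.rice_rule,
}
for name, rule in rules.items():
    print(f"{name}: {rule(a_batch=a_batch)} bins")
```
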
2 changes: 1 addition & 1 deletion quantus/functions/normalise_func.py
@@ -229,7 +229,7 @@ def normalise_by_average_second_moment_estimate(
# Cast Sequence to tuple so numpy accepts it.
normalise_axes = tuple(normalise_axes)

# Check that square root of the second momment estimatte is nonzero.
# Check that square root of the second momment estimate is nonzero.
second_moment_sqrt = np.sqrt(
np.sum(a ** 2, axis=normalise_axes, keepdims=True)
/ np.prod([a.shape[n] for n in normalise_axes])
14 changes: 13 additions & 1 deletion quantus/helpers/constants.py
@@ -13,6 +13,7 @@
from quantus.functions.normalise_func import *
from quantus.functions.perturb_func import *
from quantus.functions.similarity_func import *
from quantus.functions import n_bins_func
from quantus.metrics import *

if sys.version_info >= (3, 8):
@@ -61,7 +62,9 @@
"Effective Complexity": EffectiveComplexity,
},
"Randomisation": {
"Model Parameter Randomisation": ModelParameterRandomisation,
"MPRT": MPRT,
"Smooth MPRT": SmoothMPRT,
"Efficient MPRT": EfficientMPRT,
"Random Logit": RandomLogit,
},
"Axiomatic": {
@@ -157,6 +160,15 @@
}


AVAILABLE_N_BINS_ALGORITHMS = {
"Freedman Diaconis": n_bins_func.freedman_diaconis_rule,
"Scotts": n_bins_func.scotts_rule,
"Square Root": n_bins_func.square_root_choice,
"Sturges Formula": n_bins_func.sturges_formula,
"Rice": n_bins_func.rice_rule,
}


def available_categories() -> List[str]:
"""
Retrieve the available metric categories in Quantus.
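
A hedged lookup sketch (not part of the diff), assuming the registries shown above live in `quantus.helpers.constants` and the category dictionary is named `AVAILABLE_METRICS` (its name sits above this hunk and is not shown):

```python
from quantus.helpers import constants

# The Randomisation category now maps the renamed and new metric classes.
for name, metric_cls in constants.AVAILABLE_METRICS["Randomisation"].items():
    print(name, "->", metric_cls.__name__)
# Expected keys after this change: "MPRT", "Smooth MPRT", "Efficient MPRT", "Random Logit"

# The new bin-count rules are exposed under string keys as well.
rule = constants.AVAILABLE_N_BINS_ALGORITHMS["Freedman Diaconis"]
print(rule.__name__)  # freedman_diaconis_rule
```
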
31 changes: 29 additions & 2 deletions quantus/helpers/model/model_interface.py
@@ -1,16 +1,22 @@
"""This model implements the basics for the ModelInterface class."""

# This file is part of Quantus.
# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import warnings
from importlib import util
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, List, Union, Generator, TypeVar, Generic

import numpy as np

if util.find_spec("tensorflow"):
import tensorflow as tf
if util.find_spec("torch"):
import torch

M = TypeVar("M")


@@ -20,7 +26,7 @@ class ModelInterface(ABC, Generic[M]):
def __init__(
self,
model: M,
channel_first: bool = True,
channel_first: Optional[bool] = True,
softmax: bool = False,
model_predict_kwargs: Optional[Dict[str, Any]] = None,
):
@@ -191,3 +197,24 @@ def random_layer_generator_length(self) -> int:
Number of layers in model, which can be randomised.
"""
raise NotImplementedError

@property
def get_ml_framework_name(self) -> str:
"""
Identify the framework of the underlying model (PyTorch or TensorFlow).

Returns
-------
str
A string indicating the framework ('torch', 'tensorflow', or 'unknown').
"""
if util.find_spec("torch"):
if isinstance(self.model, torch.nn.Module):
return "torch"
if util.find_spec("tensorflow"):
if isinstance(self.model, tf.keras.Model):
return "tensorflow"
else:
warnings.warn("Cannot identify ML framework of the given model.")
return "unknown"
return ""
2 changes: 1 addition & 1 deletion quantus/helpers/model/pytorch_model.py
@@ -27,7 +27,7 @@ class PyTorchModel(ModelInterface[nn.Module]):
def __init__(
self,
model: nn.Module,
channel_first: bool = True,
channel_first: bool = False,
softmax: bool = False,
model_predict_kwargs: Optional[Dict[str, Any]] = None,
device: Optional[str] = None,
12 changes: 6 additions & 6 deletions quantus/helpers/model/tf_model.py
@@ -40,15 +40,10 @@ class TensorFlowModel(ModelInterface[Model]):
def __init__(
self,
model: Model,
channel_first: bool = True,
channel_first: bool = False,
softmax: bool = False,
model_predict_kwargs: Optional[Dict[str, ...]] = None,
):
if model_predict_kwargs is None:
model_predict_kwargs = {}
# Disable progress bar while running inference on tf.keras.Model.
model_predict_kwargs["verbose"] = 0

"""
Initialisation of ModelInterface class.

@@ -64,6 +59,11 @@ def __init__(
model_predict_kwargs: dict, optional
Keyword arguments to be passed to the model's predict method.
"""
if model_predict_kwargs is None:
model_predict_kwargs = {}
# Disable progress bar while running inference on tf.keras.Model.
model_predict_kwargs["verbose"] = 0

super().__init__(
model=model,
channel_first=channel_first,
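
A hedged sketch (not part of the diff): the reordering above keeps the keyword handling ahead of `super().__init__`, so a wrapped Keras model should still predict with the progress bar disabled. The wrapper's `predict` method and the toy network are assumptions:

```python
import numpy as np
import tensorflow as tf
from quantus.helpers.model.tf_model import TensorFlowModel

net = tf.keras.Sequential(
    [tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(10)]
)
wrapped = TensorFlowModel(model=net, channel_first=False, softmax=True)

preds = wrapped.predict(np.random.rand(4, 28, 28, 1).astype(np.float32))
print(preds.shape)  # (4, 10), produced without a Keras progress bar (verbose=0)
```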