diff --git a/docs/source/index.md b/docs/source/index.md
index 4183f051..4a53ca63 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -73,19 +73,15 @@
 If you find this toolkit or its companion paper interesting or useful in your research, please use the following Bibtex annotation to cite us:
 
 ```bibtex
-@article{hedstrom2022quantus,
-   title={Quantus: An Explainable AI Toolkit for Responsible Evaluation of Neural Network Explanations},
-   author={Anna Hedström and
-           Leander Weber and
-           Dilyara Bareeva and
-           Franz Motzkus and
-           Wojciech Samek and
-           Sebastian Lapuschkin and
-           Marina M.-C. Höhne},
-   year={2022},
-   eprint={2202.06861},
-   archivePrefix={arXiv},
-   primaryClass={cs.LG}
+@article{hedstrom2023quantus,
+  author  = {Anna Hedstr{\"{o}}m and Leander Weber and Daniel Krakowczyk and Dilyara Bareeva and Franz Motzkus and Wojciech Samek and Sebastian Lapuschkin and Marina M.{-}C. H{\"{o}}hne},
+  title   = {Quantus: An Explainable AI Toolkit for Responsible Evaluation of Neural Network Explanations and Beyond},
+  journal = {Journal of Machine Learning Research},
+  year    = {2023},
+  volume  = {24},
+  number  = {34},
+  pages   = {1--11},
+  url     = {http://jmlr.org/papers/v24/22-0142.html}
 }
 ```
 
diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py
index 7ac032ab..be2f6ca1 100644
--- a/quantus/functions/explanation_func.py
+++ b/quantus/functions/explanation_func.py
@@ -8,7 +8,7 @@
 import warnings
 from importlib import util
-from typing import Optional, Union
+from typing import Optional, Union, Callable
 
 import numpy as np
 
 import quantus
@@ -391,7 +391,6 @@ def generate_tf_explanation(
         )
 
     elif method == "SmoothGrad":
-        num_samples = kwargs.get("num_samples", 5)
         noise = kwargs.get("noise", 0.1)
 
         explainer = tf_explain.core.smoothgrad.SmoothGrad()
@@ -513,6 +512,7 @@ def generate_captum_explanation(
 
     if not isinstance(inputs, torch.Tensor):
         inputs = torch.Tensor(inputs).to(device)
+        inputs.requires_grad_()
 
     if not isinstance(targets, torch.Tensor):
         targets = torch.as_tensor(targets).to(device)
@@ -667,14 +667,22 @@ def f_reduce_axes(a):
     elif method == "Control Var. Sobel Filter":
         explanation = torch.zeros(size=inputs.shape)
 
+        if inputs.is_cuda:
+            inputs = inputs.cpu()
+
+        inputs_numpy = inputs.detach().numpy()
+
         for i in range(len(explanation)):
             explanation[i] = torch.Tensor(
-                np.clip(scipy.ndimage.sobel(inputs[i].cpu().numpy()), 0, 1)
+                np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
             )
-        explanation = explanation.mean(**reduce_axes)
+        if len(explanation.shape) > 2:
+            explanation = explanation.mean(**reduce_axes)
 
     elif method == "Control Var. Random Uniform":
-        explanation = torch.rand(size=(inputs.shape[0], *inputs.shape[2:]))
+        explanation = torch.rand(size=inputs.shape)
+        if len(explanation.shape) > 2:
+            explanation = explanation.mean(**reduce_axes)
 
     elif method == "Control Var. Constant":
         assert (
@@ -686,13 +694,16 @@ def f_reduce_axes(a):
         # Update the tensor with values per input x.
         for i in range(explanation.shape[0]):
             constant_value = get_baseline_value(
-                value=kwargs["constant_value"], arr=inputs[i], return_shape=(1,)
+                value=kwargs["constant_value"],
+                arr=inputs[i],
+                return_shape=kwargs.get("return_shape", (1,)),
             )[0]
             explanation[i] = torch.Tensor().new_full(
                 size=explanation[0].shape, fill_value=constant_value
             )
 
-        explanation = explanation.mean(**reduce_axes)
+        if len(explanation.shape) > 2:
+            explanation = explanation.mean(**reduce_axes)
 
     else:
         raise KeyError(
diff --git a/tests/functions/test_explanation_func.py b/tests/functions/test_explanation_func.py
index 8123a3c5..27ec9f1f 100644
--- a/tests/functions/test_explanation_func.py
+++ b/tests/functions/test_explanation_func.py
@@ -300,6 +300,14 @@
             },
             {"shape": (8, 1, 28, 28)},
         ),
+        (
+            lazy_fixture("titanic_model_torch"),
+            lazy_fixture("titanic_dataset"),
+            {
+                "method": "Control Var. Sobel Filter",
+            },
+            {"min": -10000.0, "max": 10000.0},
+        ),
         (
             lazy_fixture("load_1d_3ch_conv_model"),
             lazy_fixture("almost_uniform_1d_no_abatch"),
@@ -334,6 +342,15 @@
             },
             {"value": 0.0},
         ),
+        (
+            lazy_fixture("titanic_model_torch"),
+            lazy_fixture("titanic_dataset"),
+            {
+                "method": "Control Var. Constant",
+                "constant_value": 0.0,
+            },
+            {"value": 0.0},
+        ),
         (
             lazy_fixture("load_mnist_model"),
             lazy_fixture("load_mnist_images"),
@@ -342,6 +359,14 @@
             },
             {"min": 0.0, "max": 1.0},
         ),
+        (
+            lazy_fixture("titanic_model_torch"),
+            lazy_fixture("titanic_dataset"),
+            {
+                "method": "Control Var. Random Uniform",
+            },
+            {"min": 0.0, "max": 1.0},
+        ),
         (
             lazy_fixture("load_1d_3ch_conv_model"),
             lazy_fixture("almost_uniform_1d_no_abatch"),
@@ -750,7 +775,6 @@ def test_explain_func(
     params: dict,
     expected: Union[float, dict, bool],
 ):
-    x_batch, y_batch = (data["x_batch"], data["y_batch"])
 
     if "exception" in expected:
         with pytest.raises(expected["exception"]):
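For context (not part of the diff): the common thread in the `explanation_func.py` hunks is that the three "Control Var." branches now guard their channel-mean reduction behind `len(explanation.shape) > 2`, so 2D tabular batches such as the new Titanic test fixtures pass through unreduced while image batches are still averaged over the channel axis. Below is a minimal sketch of that behavior, assuming channel-first inputs and that `reduce_axes` resolves to a mean over `dim=1`; the function name and shapes are illustrative, not part of Quantus:

```python
import torch

def random_uniform_explanation(inputs: torch.Tensor) -> torch.Tensor:
    # Mirrors the patched "Control Var. Random Uniform" branch:
    # draw a uniform attribution for every input element ...
    explanation = torch.rand(size=inputs.shape)
    # ... and only reduce over the channel axis when one exists;
    # tabular (batch, features) inputs are returned unchanged.
    if len(explanation.shape) > 2:
        explanation = explanation.mean(dim=1)  # stands in for reduce_axes
    return explanation

tabular = torch.rand(8, 6)         # Titanic-like batch: (batch, features)
images = torch.rand(8, 1, 28, 28)  # MNIST-like batch: (batch, ch, h, w)
print(random_uniform_explanation(tabular).shape)  # torch.Size([8, 6])
print(random_uniform_explanation(images).shape)   # torch.Size([8, 28, 28])
```

The same guard explains the new test expectations: the Titanic cases assert only value ranges on an unreduced (batch, features) attribution, while the MNIST cases keep their channel-reduced shapes.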