Merge branch 'main' into add-tabular-tests
annahedstroem authored Mar 1, 2024
2 parents 2ffbf92 + 279fe00 commit 7d9f5c0
Showing 3 changed files with 52 additions and 21 deletions.
22 changes: 9 additions & 13 deletions docs/source/index.md
@@ -73,19 +73,15 @@ If you find this toolkit or its companion paper
interesting or useful in your research, please use the following Bibtex annotation to cite us:

```bibtex
-@article{hedstrom2022quantus,
-title={Quantus: An Explainable AI Toolkit for Responsible Evaluation of Neural Network Explanations},
-author={Anna Hedström and
-Leander Weber and
-Dilyara Bareeva and
-Franz Motzkus and
-Wojciech Samek and
-Sebastian Lapuschkin and
-Marina M.-C. Höhne},
-year={2022},
-eprint={2202.06861},
-archivePrefix={arXiv},
-primaryClass={cs.LG}
+@article{hedstrom2023quantus,
+author = {Anna Hedstr{\"{o}}m and Leander Weber and Daniel Krakowczyk and Dilyara Bareeva and Franz Motzkus and Wojciech Samek and Sebastian Lapuschkin and Marina M.{-}C. H{\"{o}}hne},
+title = {Quantus: An Explainable AI Toolkit for Responsible Evaluation of Neural Network Explanations and Beyond},
+journal = {Journal of Machine Learning Research},
+year = {2023},
+volume = {24},
+number = {34},
+pages = {1--11},
+url = {http://jmlr.org/papers/v24/22-0142.html}
}
```

25 changes: 18 additions & 7 deletions quantus/functions/explanation_func.py
Expand Up @@ -8,7 +8,7 @@

import warnings
from importlib import util
-from typing import Optional, Union
+from typing import Optional, Union, Callable

import numpy as np
import quantus
@@ -391,7 +391,6 @@ def generate_tf_explanation(
)

elif method == "SmoothGrad":
-
num_samples = kwargs.get("num_samples", 5)
noise = kwargs.get("noise", 0.1)
explainer = tf_explain.core.smoothgrad.SmoothGrad()
@@ -513,6 +512,7 @@ def generate_captum_explanation(

if not isinstance(inputs, torch.Tensor):
inputs = torch.Tensor(inputs).to(device)
+inputs.requires_grad_()

if not isinstance(targets, torch.Tensor):
targets = torch.as_tensor(targets).to(device)
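
The one-line addition in this hunk, `inputs.requires_grad_()`, turns on autograd tracking for the input batch, which Captum's gradient-based attribution methods need in order to differentiate the model output with respect to the inputs. A minimal standalone PyTorch sketch of the difference (the linear model and shapes are illustrative, not taken from this repository):

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.rand(8, 4)

# Without gradient tracking on x, differentiating w.r.t. it fails:
# torch.autograd.grad(model(x).sum(), x)  # -> RuntimeError

x.requires_grad_()  # the analogue of the added line for `inputs`
grads = torch.autograd.grad(model(x).sum(), x)[0]
print(grads.shape)  # torch.Size([8, 4])
```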
@@ -667,14 +667,22 @@ def f_reduce_axes(a):
elif method == "Control Var. Sobel Filter":
explanation = torch.zeros(size=inputs.shape)

+if inputs.is_cuda:
+    inputs = inputs.cpu()
+
+inputs_numpy = inputs.detach().numpy()
+
for i in range(len(explanation)):
explanation[i] = torch.Tensor(
-np.clip(scipy.ndimage.sobel(inputs[i].cpu().numpy()), 0, 1)
+np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
)
-explanation = explanation.mean(**reduce_axes)
+if len(explanation.shape) > 2:
+    explanation = explanation.mean(**reduce_axes)

elif method == "Control Var. Random Uniform":
-explanation = torch.rand(size=(inputs.shape[0], *inputs.shape[2:]))
+explanation = torch.rand(size=(inputs.shape))
+if len(explanation.shape) > 2:
+    explanation = explanation.mean(**reduce_axes)

elif method == "Control Var. Constant":
assert (
@@ -686,13 +694,16 @@ def f_reduce_axes(a):
# Update the tensor with values per input x.
for i in range(explanation.shape[0]):
constant_value = get_baseline_value(
value=kwargs["constant_value"], arr=inputs[i], return_shape=(1,)
value=kwargs["constant_value"],
arr=inputs[i],
return_shape=kwargs.get("return_shape", (1,)),
)[0]
explanation[i] = torch.Tensor().new_full(
size=explanation[0].shape, fill_value=constant_value
)

-explanation = explanation.mean(**reduce_axes)
+if len(explanation.shape) > 2:
+    explanation = explanation.mean(**reduce_axes)

else:
raise KeyError(
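Taken together, the edits to the three "Control Var." branches let the baseline explanations handle tabular inputs of shape `(N, F)` as well as image inputs of shape `(N, C, H, W)`: CUDA tensors are moved to the CPU before the SciPy call, and the mean reduction is only applied when the explanation has more than two dimensions. A minimal standalone sketch of the Sobel-filter branch under those assumptions (the function name and the fixed `dim=1` reduction are illustrative; the actual code takes its reduction axes from `reduce_axes`):

```python
import numpy as np
import scipy.ndimage
import torch


def sobel_baseline_explanation(inputs: torch.Tensor) -> torch.Tensor:
    """Illustrative re-implementation of the 'Control Var. Sobel Filter' branch."""
    explanation = torch.zeros(size=inputs.shape)

    # CUDA tensors cannot be converted to NumPy directly; move them to the CPU first.
    if inputs.is_cuda:
        inputs = inputs.cpu()
    inputs_numpy = inputs.detach().numpy()

    for i in range(len(explanation)):
        explanation[i] = torch.Tensor(
            np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
        )

    # Only image-shaped explanations get the channel-mean reduction;
    # 2-D tabular explanations are returned as-is.
    if len(explanation.shape) > 2:
        explanation = explanation.mean(dim=1)
    return explanation


tabular = sobel_baseline_explanation(torch.rand(8, 11))        # shape (8, 11)
images = sobel_baseline_explanation(torch.rand(8, 1, 28, 28))  # shape (8, 28, 28)
```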
26 changes: 25 additions & 1 deletion tests/functions/test_explanation_func.py
@@ -300,6 +300,14 @@
},
{"shape": (8, 1, 28, 28)},
),
+(
+lazy_fixture("titanic_model_torch"),
+lazy_fixture("titanic_dataset"),
+{
+"method": "Control Var. Sobel Filter",
+},
+{"min": -10000.0, "max": 10000.0},
+),
(
lazy_fixture("load_1d_3ch_conv_model"),
lazy_fixture("almost_uniform_1d_no_abatch"),
@@ -334,6 +342,15 @@
},
{"value": 0.0},
),
+(
+lazy_fixture("titanic_model_torch"),
+lazy_fixture("titanic_dataset"),
+{
+"method": "Control Var. Constant",
+"constant_value": 0.0,
+},
+{"value": 0.0},
+),
(
lazy_fixture("load_mnist_model"),
lazy_fixture("load_mnist_images"),
@@ -342,6 +359,14 @@
},
{"min": 0.0, "max": 1.0},
),
+(
+lazy_fixture("titanic_model_torch"),
+lazy_fixture("titanic_dataset"),
+{
+"method": "Control Var. Random Uniform",
+},
+{"min": 0.0, "max": 1.0},
+),
(
lazy_fixture("load_1d_3ch_conv_model"),
lazy_fixture("almost_uniform_1d_no_abatch"),
@@ -750,7 +775,6 @@ def test_explain_func(
params: dict,
expected: Union[float, dict, bool],
):
-
x_batch, y_batch = (data["x_batch"], data["y_batch"])
if "exception" in expected:
with pytest.raises(expected["exception"]):
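Each new parametrization entry pairs the `titanic_model_torch` and `titanic_dataset` fixtures with one of the control-variant methods plus an expectation dictionary (`min`/`max` bounds or an exact `value`). A minimal self-contained sketch of that pattern with a stand-in for the explanation call (the test name, shapes, and stand-in logic are illustrative; the assertions only mirror how such expectation dictionaries are commonly checked):

```python
import numpy as np
import pytest


@pytest.mark.parametrize(
    "params,expected",
    [
        ({"method": "Control Var. Random Uniform"}, {"min": 0.0, "max": 1.0}),
        ({"method": "Control Var. Constant", "constant_value": 0.0}, {"value": 0.0}),
    ],
)
def test_control_variant_expectations(params: dict, expected: dict):
    # Stand-in for the explain call: produce the baseline "explanation" directly.
    if params["method"] == "Control Var. Random Uniform":
        explanation = np.random.uniform(0.0, 1.0, size=(8, 11))
    else:
        explanation = np.full((8, 11), params["constant_value"])

    # "min"/"max" bound the attribution values, "value" pins them exactly.
    if "min" in expected:
        assert explanation.min() >= expected["min"]
    if "max" in expected:
        assert explanation.max() <= expected["max"]
    if "value" in expected:
        assert np.all(explanation == expected["value"])
```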