Commit 3e98e9f (parent: 9767b30). 4 changed files with 150 additions and 0 deletions.
src/explainers/aggregators.py (new file, +43 lines)
from abc import ABC, abstractmethod

import torch


class ExplanationsAggregator(ABC):
    def __init__(self, training_size: int, *args, **kwargs):
        self.scores = torch.zeros(training_size)

    @abstractmethod
    def update(self, explanations: torch.Tensor):
        raise NotImplementedError

    def reset(self, *args, **kwargs):
        """
        Resets the aggregator state.
        """
        self.scores = torch.zeros_like(self.scores)

    def load_state_dict(self, state_dict: dict, *args, **kwargs):
        """
        Loads the aggregator state.
        """
        self.scores = state_dict["scores"]

    def state_dict(self, *args, **kwargs):
        """
        Returns the aggregator state.
        """
        return {"scores": self.scores}

    def compute(self) -> torch.Tensor:
        return self.scores.argsort()


class SumAggregator(ExplanationsAggregator):
    def update(self, explanations: torch.Tensor) -> None:
        self.scores += explanations.sum(dim=0)


class AbsSumAggregator(ExplanationsAggregator):
    def update(self, explanations: torch.Tensor) -> None:
        self.scores += explanations.abs().sum(dim=0)
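A minimal usage sketch of the aggregators above (the training size, batch shape, and random explanation tensor are illustrative assumptions, not part of the commit):

import torch

from src.explainers.aggregators import SumAggregator

# hypothetical sizes: 100 training samples, a batch of 8 explanation rows
aggregator = SumAggregator(training_size=100)
explanations = torch.randn(8, 100)  # one attribution per (test sample, training sample) pair
aggregator.update(explanations)

# indices of training samples, ordered from least to most aggregated attribution
global_rank = aggregator.compute()
assert global_rank.shape == (100,)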
src/explainers/captum/base.py (new file, +35 lines)
from typing import List, Union

import torch
from captum.influence import DataInfluence

from src.explainers.base import Explainer


class CaptumExplainerWrapper(Explainer):
    def __init__(
        self,
        model: torch.nn.Module,
        model_id: str,
        train_dataset: torch.utils.data.Dataset,
        device: Union[str, torch.device],
        explainer_cls: DataInfluence,
        **explainer_kwargs,
    ):
        super().__init__(model=model, model_id=model_id, train_dataset=train_dataset, device=device)
        self.captum_explainer = explainer_cls(model=model, train_dataset=train_dataset, **explainer_kwargs)

    def explain(
        self, test: torch.Tensor, targets: Union[torch.Tensor, List[int], None], **explainer_kwargs
    ) -> torch.Tensor:
        if targets is not None:
            if not isinstance(targets, torch.Tensor):
                if isinstance(targets, list):
                    targets = torch.tensor(targets)
                else:
                    raise TypeError(
                        f"targets should be of type NoneType, List or torch.Tensor. Got {type(targets)} instead."
                    )
            return self.captum_explainer.influence(inputs=(test, targets), **explainer_kwargs)
        else:
            return self.captum_explainer.influence(inputs=test, **explainer_kwargs)
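A minimal sketch of the targets handling, using a stub in place of a real captum DataInfluence subclass (FakeInfluence, the toy model, and the empty dataset are hypothetical stand-ins; real captum influence classes have their own constructor requirements):

import torch

from src.explainers.captum.base import CaptumExplainerWrapper


class FakeInfluence:
    """Stub standing in for a captum DataInfluence subclass."""

    def __init__(self, model, train_dataset):
        pass

    def influence(self, inputs):
        return inputs  # echo back, for demonstration only


explainer = CaptumExplainerWrapper(
    model=torch.nn.Linear(4, 2),  # hypothetical model
    model_id="toy",
    train_dataset=[],  # hypothetical dataset
    device="cpu",
    explainer_cls=FakeInfluence,
)
test = torch.randn(2, 4)
explainer.explain(test, targets=None)  # forwards inputs=test
explainer.explain(test, targets=[0, 1])  # list is converted to a tensor first
explainer.explain(test, targets=torch.tensor([0, 1]))  # tensor passes through
# explainer.explain(test, targets="bad")  # would raise TypeError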
New file (+28 lines): CaptumSimilarityExplainer
from typing import Union

import torch
from captum.influence import SimilarityInfluence

from src.explainers.captum.base import CaptumExplainerWrapper


class CaptumSimilarityExplainer(CaptumExplainerWrapper):
    def __init__(
        self,
        model: torch.nn.Module,
        model_id: str,
        train_dataset: torch.utils.data.Dataset,
        device: Union[str, torch.device],
        **explainer_kwargs,
    ):
        super().__init__(
            model=model,
            model_id=model_id,
            train_dataset=train_dataset,
            device=device,
            explainer_cls=SimilarityInfluence,
            **explainer_kwargs,
        )

    def explain(self, test: torch.Tensor) -> torch.Tensor:
        return super().explain(test=test, targets=None)
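Note the design choice here: similarity-based influence compares activations rather than label-dependent quantities, so it has no notion of prediction targets. CaptumSimilarityExplainer therefore narrows the interface and always delegates with targets=None; any constructor arguments that SimilarityInfluence itself requires must be supplied through **explainer_kwargs and are forwarded verbatim.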
New file (+44 lines): tests for the aggregators
import pytest
import torch

from src.explainers.aggregators import AbsSumAggregator, SumAggregator


@pytest.mark.aggregators
@pytest.mark.parametrize(
    "test_id, dataset, explanations",
    [
        (
            "mnist",
            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
def test_sum_aggregator(test_id, dataset, explanations, request):
    dataset = request.getfixturevalue(dataset)
    explanations = request.getfixturevalue(explanations)
    aggregator = SumAggregator(training_size=len(dataset))
    aggregator.update(explanations)
    global_rank = aggregator.compute()
    assert torch.allclose(global_rank, explanations.sum(dim=0).argsort())


@pytest.mark.aggregators
@pytest.mark.parametrize(
    "test_id, dataset, explanations",
    [
        (
            "mnist",
            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
def test_abs_aggregator(test_id, dataset, explanations, request):
    dataset = request.getfixturevalue(dataset)
    explanations = request.getfixturevalue(explanations)
    aggregator = AbsSumAggregator(training_size=len(dataset))
    aggregator.update(explanations)
    global_rank = aggregator.compute()
    # mean and sum differ only by the positive factor 1/N, so their argsorts agree
    assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort())
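One detail worth calling out: the last assertion ranks by the mean of absolute attributions, while AbsSumAggregator accumulates their sum. The two rankings coincide because argsort is invariant under division by a positive constant, which a standalone check (with an arbitrary random tensor) confirms:

import torch

x = torch.randn(8, 100).abs()
# dividing by the positive batch size (here 8) preserves the ordering
assert torch.equal(x.sum(dim=0).argsort(), x.mean(dim=0).argsort())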