Commit

Last commit actually here
gumityolcu committed Jun 14, 2024
1 parent 9767b30 commit 3e98e9f
Showing 4 changed files with 150 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/explainers/aggregators.py
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod

import torch


class ExplanationsAggregator(ABC):
    def __init__(self, training_size: int, *args, **kwargs):
        self.scores = torch.zeros(training_size)

    @abstractmethod
    def update(self, explanations: torch.Tensor):
        raise NotImplementedError

    def reset(self, *args, **kwargs):
        """
        Used to reset the aggregator state.
        """
        self.scores = torch.zeros_like(self.scores)

    def load_state_dict(self, state_dict: dict, *args, **kwargs):
        """
        Used to load the aggregator state.
        """
        self.scores = state_dict["scores"]

    def state_dict(self, *args, **kwargs):
        """
        Used to return the aggregator state.
        """
        return {"scores": self.scores}

    def compute(self) -> torch.Tensor:
        """
        Used to return training sample indices, sorted by aggregated score in ascending order.
        """
        return self.scores.argsort()


class SumAggregator(ExplanationsAggregator):
    def update(self, explanations: torch.Tensor) -> None:
        # Accumulate the signed attribution mass assigned to each training sample.
        self.scores += explanations.sum(dim=0)


class AbsSumAggregator(ExplanationsAggregator):
    def update(self, explanations: torch.Tensor) -> None:
        # Accumulate the magnitude of attributions, ignoring their sign.
        self.scores += explanations.abs().sum(dim=0)
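
For reference, a minimal usage sketch of the aggregator API introduced above (the explanation tensor and sizes are synthetic, made up for illustration):

import torch

from src.explainers.aggregators import SumAggregator

# Synthetic batch: 8 test explanations over a training set of 100 samples.
explanations = torch.randn(8, 100)

aggregator = SumAggregator(training_size=100)
aggregator.update(explanations)

# Training sample indices, sorted by aggregated attribution (ascending).
global_rank = aggregator.compute()

# The state can be checkpointed and restored.
state = aggregator.state_dict()
aggregator.reset()
aggregator.load_state_dict(state)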
35 changes: 35 additions & 0 deletions src/explainers/captum/base.py
@@ -0,0 +1,35 @@
from typing import List, Union

import torch
from captum.influence import DataInfluence

from src.explainers.base import Explainer


class CaptumExplainerWrapper(Explainer):
    def __init__(
        self,
        model: torch.nn.Module,
        model_id: str,
        train_dataset: torch.utils.data.Dataset,
        device: Union[str, torch.device],
        explainer_cls: DataInfluence,
        **explainer_kwargs,
    ):
        super().__init__(model=model, model_id=model_id, train_dataset=train_dataset, device=device)
        self.captum_explainer = explainer_cls(model=model, train_dataset=train_dataset, **explainer_kwargs)

    def explain(
        self, test: torch.Tensor, targets: Union[torch.Tensor, List[int], None], **explainer_kwargs
    ) -> torch.Tensor:
        if targets is None:
            return self.captum_explainer.influence(inputs=test, **explainer_kwargs)
        if isinstance(targets, list):
            targets = torch.tensor(targets)
        elif not isinstance(targets, torch.Tensor):
            raise TypeError(
                f"targets should be of type NoneType, List or torch.Tensor. Got {type(targets)} instead."
            )
        return self.captum_explainer.influence(inputs=(test, targets), **explainer_kwargs)
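
To illustrate the intended extension pattern, here is a hypothetical subclass for Captum's TracInCP (not part of this commit; it assumes a Captum version whose TracInCP constructor accepts model and train_dataset keywords, with checkpoints and any other TracInCP arguments forwarded through **explainer_kwargs):

from typing import Union

import torch
from captum.influence import TracInCP

from src.explainers.captum.base import CaptumExplainerWrapper


class CaptumTracInCPExplainer(CaptumExplainerWrapper):
    # Hypothetical wrapper, for illustration only.
    def __init__(
        self,
        model: torch.nn.Module,
        model_id: str,
        train_dataset: torch.utils.data.Dataset,
        device: Union[str, torch.device],
        **explainer_kwargs,
    ):
        super().__init__(
            model=model,
            model_id=model_id,
            train_dataset=train_dataset,
            device=device,
            explainer_cls=TracInCP,
            **explainer_kwargs,
        )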
28 changes: 28 additions & 0 deletions src/explainers/captum/similarity.py
@@ -0,0 +1,28 @@
from typing import Union

import torch
from captum.influence import SimilarityInfluence

from src.explainers.captum.base import CaptumExplainerWrapper


class CaptumSimilarityExplainer(CaptumExplainerWrapper):
    def __init__(
        self,
        model: torch.nn.Module,
        model_id: str,
        train_dataset: torch.utils.data.Dataset,
        device: Union[str, torch.device],
        **explainer_kwargs,
    ):
        super().__init__(
            model=model,
            model_id=model_id,
            train_dataset=train_dataset,
            device=device,
            explainer_cls=SimilarityInfluence,
            **explainer_kwargs,
        )

    def explain(self, test: torch.Tensor) -> torch.Tensor:
        # Similarity-based attribution needs no targets.
        return super().explain(test=test, targets=None)
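
A usage sketch for the new wrapper (toy stand-ins for the model and dataset; SimilarityInfluence's own constructor arguments, e.g. a layer name and an activation cache directory, are assumed to be accepted through **explainer_kwargs):

import torch
from torch.utils.data import TensorDataset

from src.explainers.captum.similarity import CaptumSimilarityExplainer

# Toy stand-ins, made up for illustration.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
train_dataset = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

explainer = CaptumSimilarityExplainer(
    model=model,
    model_id="toy_model",
    train_dataset=train_dataset,
    device="cpu",
    # SimilarityInfluence-specific settings would go here as **explainer_kwargs.
)

# One row of similarity scores per test sample, one column per training sample.
scores = explainer.explain(test=torch.randn(8, 1, 28, 28))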
44 changes: 44 additions & 0 deletions tests/explainers/test_aggregators.py
@@ -0,0 +1,44 @@
import pytest
import torch

from src.explainers.aggregators import AbsSumAggregator, SumAggregator


@pytest.mark.aggregators
@pytest.mark.parametrize(
    "test_id, dataset, explanations",
    [
        (
            "mnist",
            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
def test_sum_aggregator(test_id, dataset, explanations, request):
    dataset = request.getfixturevalue(dataset)
    explanations = request.getfixturevalue(explanations)
    aggregator = SumAggregator(training_size=len(dataset))
    aggregator.update(explanations)
    global_rank = aggregator.compute()
    assert torch.allclose(global_rank, explanations.sum(dim=0).argsort())


@pytest.mark.aggregators
@pytest.mark.parametrize(
    "test_id, dataset, explanations",
    [
        (
            "mnist",
            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
def test_abs_aggregator(test_id, dataset, explanations, request):
    dataset = request.getfixturevalue(dataset)
    explanations = request.getfixturevalue(explanations)
    aggregator = AbsSumAggregator(training_size=len(dataset))
    aggregator.update(explanations)
    global_rank = aggregator.compute()
    assert torch.allclose(global_rank, explanations.abs().sum(dim=0).argsort())
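
A state round-trip test along the same lines could round out the suite (a sketch reusing the fixtures above; it relies on reset() reallocating the scores tensor rather than zeroing it in place):

@pytest.mark.aggregators
@pytest.mark.parametrize(
    "test_id, dataset, explanations",
    [
        (
            "mnist",
            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
def test_sum_aggregator_state_round_trip(test_id, dataset, explanations, request):
    dataset = request.getfixturevalue(dataset)
    explanations = request.getfixturevalue(explanations)
    aggregator = SumAggregator(training_size=len(dataset))
    aggregator.update(explanations)
    state = aggregator.state_dict()
    aggregator.reset()
    aggregator.load_state_dict(state)
    assert torch.allclose(aggregator.compute(), explanations.sum(dim=0).argsort())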
