diff --git a/src/explainers/aggregators/aggregators.py b/src/explainers/aggregators/aggregators.py deleted file mode 100644 index 0d36b6b4..00000000 --- a/src/explainers/aggregators/aggregators.py +++ /dev/null @@ -1,43 +0,0 @@ -from abc import ABC, abstractmethod - -import torch - - -class ExplanationsAggregator(ABC): - def __init__(self, training_size: int, *args, **kwargs): - self.scores = torch.zeros(training_size) - - @abstractmethod - def update(self, explanations: torch.Tensor): - raise NotImplementedError - - def reset(self, *args, **kwargs): - """ - Used to reset the aggregator state. - """ - self.scores = torch.zeros_like(self.scores) - - def load_state_dict(self, state_dict: dict, *args, **kwargs): - """ - Used to load the aggregator state. - """ - self.scores = state_dict["scores"] - - def state_dict(self, *args, **kwargs): - """ - Used to return the metric state. - """ - return {"scores": self.scores} - - def compute(self) -> torch.Tensor: - return self.scores.argsort() - - -class SumAggregator(ExplanationsAggregator): - def update(self, explanations: torch.Tensor) -> torch.Tensor: - self.scores += explanations.sum(dim=0) - - -class AbsSumAggregator(ExplanationsAggregator): - def update(self, explanations: torch.Tensor) -> torch.Tensor: - self.scores += explanations.abs().sum(dim=0) diff --git a/src/explainers/base.py b/src/explainers/base.py index 5a8e9f69..d29123e0 100644 --- a/src/explainers/base.py +++ b/src/explainers/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import List, Optional, Union import torch @@ -23,7 +23,7 @@ def __init__( self.model.to(dev) @abstractmethod - def explain(self, test: torch.Tensor, **kwargs): + def explain(self, test: torch.Tensor, targets: Union[List[int], torch.Tensor, None], **kwargs): raise NotImplementedError @abstractmethod diff --git a/tests/explainers/aggregators/test_aggregators.py b/tests/explainers/aggregators/test_aggregators.py deleted file mode 100644 index cfab2355..00000000 --- a/tests/explainers/aggregators/test_aggregators.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -from src.explainers.aggregators.aggregators import ( - AbsSumAggregator, - SumAggregator, -) - - -@pytest.mark.aggregators -@pytest.mark.parametrize( - "test_id, dataset, explanations", - [ - ( - "mnist", - "load_mnist_dataset", - "load_mnist_explanations_1", - ), - ], -) -def test_sum_aggregator(test_id, dataset, explanations, request): - dataset = request.getfixturevalue(dataset) - explanations = request.getfixturevalue(explanations) - aggregator = SumAggregator(training_size=len(dataset)) - aggregator.update(explanations) - global_rank = aggregator.compute() - assert torch.allclose(global_rank, explanations.sum(dim=0).argsort()) - - -@pytest.mark.aggregators -@pytest.mark.parametrize( - "test_id, dataset, explanations", - [ - ( - "mnist", - "load_mnist_dataset", - "load_mnist_explanations_1", - ), - ], -) -def test_abs_aggregator(test_id, dataset, explanations, request): - dataset = request.getfixturevalue(dataset) - explanations = request.getfixturevalue(explanations) - aggregator = AbsSumAggregator(training_size=len(dataset)) - aggregator.update(explanations) - global_rank = aggregator.compute() - assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort())