diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 77e1c2ae..958a61b1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,6 +7,7 @@ repos:
       - id: trailing-whitespace
       - id: end-of-file-fixer
      - id: check-yaml
+      - id: check-added-large-files
 
   - repo: local
     hooks:
diff --git a/pytest.ini b/pytest.ini
index 2805543f..a2b7d997 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -2,3 +2,5 @@
 markers =
     utils: utils files
     explainers: explainers
+    localization_metrics: localization_metrics
+    unnamed_metrics: unnamed_metrics
diff --git a/src/explainers/explain_wrapper.py b/src/explainers/explain_wrapper.py
index b48d8c94..f426d5b8 100644
--- a/src/explainers/explain_wrapper.py
+++ b/src/explainers/explain_wrapper.py
@@ -41,5 +41,7 @@ def explain(
         similarity_direction=sim_direction,
         batch_size=batch_size,
     )
-    return sim_influence.influence(test_tensor, top_k)[layer]
+    topk_idx, topk_val = sim_influence.influence(test_tensor, len(train_dataset))[layer]
+    tda = torch.gather(topk_val, 1, topk_idx)
+    return tda
diff --git a/src/metrics/base.py b/src/metrics/base.py
index 52852a3b..6c09cabf 100644
--- a/src/metrics/base.py
+++ b/src/metrics/base.py
@@ -13,7 +13,7 @@ def __init__(self, device, *args, **kwargs):
     @abstractmethod
     def __call__(
         self,
-        explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations],
+        *args,
         **kwargs,
     ):
         """
diff --git a/src/metrics/localization/identical_class.py b/src/metrics/localization/identical_class.py
index 66187e8a..38cdfbda 100644
--- a/src/metrics/localization/identical_class.py
+++ b/src/metrics/localization/identical_class.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Optional, Union
 
 import torch
@@ -16,49 +17,70 @@ def __init__(self, device, *args, **kwargs):
 
     def __call__(
         self,
-        test_predictions: torch.Tensor,
-        batch_size: int = 1,
+        model: torch.nn.Module,
+        train_dataset: torch.utils.data.Dataset,
+        test_dataset: torch.utils.data.Dataset,
         explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
+        batch_size: Optional[int] = 8,
         **kwargs,
     ):
         """
-        :param test_predictions:
+        :param model:
+        :param train_dataset:
+        :param test_dataset:
         :param explanations:
         :param batch_size:
         :param kwargs:
         :return:
         """
         if isinstance(explanations, str):
             explanations = EC.load(path=explanations, device=self.device)
+            if explanations.batch_size != batch_size:
+                warnings.warn(
+                    "Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
+                    "size will be used instead."
+                )
+                batch_size = explanations.batch_size
         elif isinstance(explanations, torch.Tensor):
-            explanations = TensorExplanations(explanations)
-
-        # assert len(test_dataset) == len(explanations)
-        assert test_predictions.shape[0] == batch_size * len(
-            explanations
-        ), f"Length of test predictions {test_predictions.shape[0]} and explanations {len(explanations)} do not match"
+            explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)
 
         scores = []
-        for i in range(test_predictions.shape[0] // batch_size + 1):
+        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+        for i, data in enumerate(test_dataloader):
+            # default_collate returns a list (not a tuple) for (x, y) datasets
+            if isinstance(data, (tuple, list)):
+                data = data[0]
+            assert data.shape[0] == explanations[i].shape[0], (
+                f"Batch size mismatch between explanations and input samples: "
+                f"{data.shape[0]} != {explanations[i].shape[0]} for batch {i}."
+            )
             score = self._evaluate_instance(
-                test_labels=test_predictions[i * batch_size : i * batch_size + 1],
+                model=model,
+                train_dataset=train_dataset,
+                x_batch=data,
                 xpl=explanations[i],
             )
             scores.append(score)
 
-        return {"score": torch.tensor(scores).mean()}
+        return {"score": torch.cat(scores).mean()}
 
     def _evaluate_instance(
         self,
-        test_labels: torch.Tensor,
+        model: torch.nn.Module,
+        train_dataset: torch.utils.data.Dataset,
+        x_batch: torch.Tensor,
         xpl: torch.Tensor,
     ):
         """
         Used to implement metric-specific logic.
         """
-        top_one_xpl_labels = xpl.argmax(dim=1)
+        top_one_xpl_indices = xpl.argmax(dim=1)
+        top_one_xpl_samples = torch.stack([train_dataset[i][0] for i in top_one_xpl_indices])
+
+        test_output = model(x_batch.to(self.device))
+        test_pred = test_output.argmax(dim=1)
+
+        top_one_xpl_output = model(top_one_xpl_samples.to(self.device))
+        top_one_xpl_pred = top_one_xpl_output.argmax(dim=1)
 
-        return (test_labels == top_one_xpl_labels) * 1.0
+        return (test_pred == top_one_xpl_pred) * 1.0
diff --git a/src/metrics/unnamed/top_k_overlap.py b/src/metrics/unnamed/top_k_overlap.py
index 45283d16..b01a7dc8 100644
--- a/src/metrics/unnamed/top_k_overlap.py
+++ b/src/metrics/unnamed/top_k_overlap.py
@@ -1,3 +1,4 @@
+import warnings
 from collections import Counter
 from typing import Optional, Union
 
@@ -17,10 +18,11 @@ def __init__(self, device, *args, **kwargs):
 
     def __call__(
         self,
-        test_logits: torch.Tensor,
-        batch_size: int = 1,
+        model: torch.nn.Module,
+        train_dataset: torch.utils.data.Dataset,
         top_k: int = 1,
         explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
+        batch_size: Optional[int] = 8,
         **kwargs,
     ):
         """
@@ -34,44 +36,38 @@
         if isinstance(explanations, str):
             explanations = EC.load(path=explanations, device=self.device)
+            if explanations.batch_size != batch_size:
+                warnings.warn(
+                    "Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
+                    "size will be used instead."
+                )
+                batch_size = explanations.batch_size
         elif isinstance(explanations, torch.Tensor):
-            explanations = TensorExplanations(explanations)
-
-        # assert len(test_dataset) == len(explanations)
-        assert test_logits.shape[0] == batch_size * len(
-            explanations
-        ), f"Length of test logits {test_logits.shape[0]} and explanations {len(explanations)} do not match"
+            explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)
 
         all_top_k_examples = []
-        all_top_k_probs = []
-        for i in range(test_logits.shape[0] // batch_size + 1):
-            top_k_examples, top_k_probs = self._evaluate_instance(
-                test_logits=test_logits[i * batch_size : i * batch_size + 1],
+
+        for i in range(len(explanations)):
+            top_k_examples = self._evaluate_instance(
                 xpl=explanations[i],
+                top_k=top_k,
             )
             all_top_k_examples += top_k_examples
-            all_top_k_probs += top_k_probs
-        all_top_k_probs = torch.stack(all_top_k_probs)
 
         # calculate the cardinality of the set of top-k examples
         cardinality = len(set(all_top_k_examples))
-        # find the index of the first occurence of the top-k examples
-        indices = [all_top_k_examples.index(ex) for ex in set(all_top_k_examples)]
-        # calculate the probability of the set of top-k examples
-        probability = all_top_k_probs[indices].mean()
-        return {"cardinality": cardinality, "probability": probability}
+        # TODO: calculate the probability of the set of top-k examples
+        return {"score": cardinality}
 
     def _evaluate_instance(
         self,
-        test_logits: torch.Tensor,
         xpl: torch.Tensor,
         top_k: int = 1,
     ):
         """
         Used to implement metric-specific logic.
         """
-        top_k_examples = torch.topk(xpl.flatten(), top_k).indices
-        top_k_probs = torch.softmax(test_logits, dim=1)[top_k_examples]
-        return top_k_examples, top_k_probs
+        top_k_indices = torch.topk(xpl, top_k).indices
+        return top_k_indices
diff --git a/src/utils/explanations.py b/src/utils/explanations.py
index 91d4dc25..4c97190c 100644
--- a/src/utils/explanations.py
+++ b/src/utils/explanations.py
@@ -11,6 +11,14 @@ def __init__(
         *args,
         **kwargs,
     ):
+        """
+        Explanations interface class. Used to define the interface for the Explanations classes.
+        Each explanation class implements the __getitem__, __setitem__, and __len__ methods, whereby an
+        "item" is an explanation tensor batch.
+
+        :param args:
+        :param kwargs:
+        """
         pass
 
     def __getitem__(self, index: Union[int, slice]) -> torch.Tensor:
@@ -27,11 +35,12 @@ class TensorExplanations(Explanations):
     def __init__(
         self,
         tensor: torch.Tensor,
+        batch_size: Optional[int] = 8,
         device: str = "cpu",
     ):
         """
         Returns explanations from cache saved as tensors. __getitem__ and __setitem__ methods are used to access the
-        explanations on per-sample basis.
+        explanations on a batch basis.
 
         :param dataset_id:
         :param top_k:
@@ -40,10 +49,10 @@ def __init__(
         super().__init__()
         self.device = device
         self.xpl = tensor.to(self.device)
+        self.batch_size = batch_size
 
-        # assert the number of explanation dimensions is 2 and insert extra dimension to emulate batching
+        # assert the number of explanation dimensions is 2
         assert len(self.xpl.shape) == 2, "Explanations object has more than 2 dimensions."
-        self.xpl = self.xpl.unsqueeze(1)
 
     def __getitem__(self, idx: Union[int, slice]) -> torch.Tensor:
         """
@@ -51,7 +60,7 @@ def __getitem__(self, idx: Union[int, slice]) -> torch.Tensor:
         :param idx:
         :return:
         """
-        return self.xpl[idx]
+        return self.xpl[idx * self.batch_size : min((idx + 1) * self.batch_size, self.xpl.shape[0])]
 
     def __setitem__(self, idx: Union[int, slice], val: Tuple[torch.Tensor, torch.Tensor]):
         """
@@ -61,18 +70,17 @@ def __setitem__(self, idx: Union[int, slice], val: Tuple[torch.Tensor, torch.Ten
         :param idx:
         :param val:
         :return:
         """
-        self.xpl[idx] = val
+        self.xpl[idx * self.batch_size : (idx + 1) * self.batch_size] = val
         return val
 
     def __len__(self) -> int:
-        return self.xpl.shape[0]
+        return (self.xpl.shape[0] + self.batch_size - 1) // self.batch_size  # ceil division: number of batches
 
 
 class BatchedCachedExplanations(Explanations):
     def __init__(
         self,
         cache_dir: str = "./batch_wise_cached_explanations",
-        batch_size: Optional[int] = None,
         device: str = "cpu",
     ):
         """
@@ -84,12 +92,12 @@ def __init__(
         :param cache_dir:
         """
         super().__init__()
-        self.batch_size = batch_size
         self.cache_dir = cache_dir
         self.device = device
 
         self.av_filesearch = os.path.join(cache_dir, "*.pt")
         self.files = glob.glob(self.av_filesearch)
+        self.batch_size = self[0].shape[0]
 
     def __getitem__(self, idx: int) -> torch.Tensor:
         """
diff --git a/tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt b/tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt
new file mode 100644
index 00000000..50e65d6e
Binary files /dev/null and b/tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt differ
diff --git a/tests/assets/mnist_test_suite_1/test_dataset.pt b/tests/assets/mnist_test_suite_1/test_dataset.pt
new file mode 100644
index 00000000..cb13a259
Binary files /dev/null and b/tests/assets/mnist_test_suite_1/test_dataset.pt differ
diff --git a/tests/conftest.py b/tests/conftest.py
index e1b8e231..f9e15f96 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -58,6 +58,4 @@ def load_mnist_test_samples_1():
 
 @pytest.fixture()
 def load_mnist_explanations_1():
-    rankings = torch.load(f"tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_ranking.pt")
-    tda = torch.load(f"tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt")
-    return rankings, tda
+    return torch.load("tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt")
diff --git a/tests/explainers/test_explain_wrapper.py b/tests/explainers/test_explain_wrapper.py
index f2f6553f..fdb114d5 100644
--- a/tests/explainers/test_explain_wrapper.py
+++ b/tests/explainers/test_explain_wrapper.py
@@ -25,8 +25,8 @@ def test_explain(test_id, model, dataset, explanations, test_tensor, method, met
     model = request.getfixturevalue(model)
     dataset = request.getfixturevalue(dataset)
     test_tensor = request.getfixturevalue(test_tensor)
-    ranking_exp, tda_exp = request.getfixturevalue(explanations)
-    ranking, tda = explain(
+    tda_exp = request.getfixturevalue(explanations)
+    tda = explain(
         model,
         test_id,
         os.path.join("./cache", "test_id"),
@@ -35,5 +35,4 @@
         method,
         **method_kwargs,
     )
-    assert torch.allclose(ranking, ranking_exp), "Explanation rankings are not as expected"
     assert torch.allclose(tda, tda_exp), "Training data attributions are not as expected"
diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py
index b0e1cb1d..9ad27a7c 100644
--- a/tests/metrics/test_localization_metrics.py
+++ b/tests/metrics/test_localization_metrics.py
@@ -3,16 +3,26 @@
 from metrics.localization.identical_class import IdenticalClass
 
 
-@pytest.mark.utils
+@pytest.mark.localization_metrics
 @pytest.mark.parametrize(
-    "test_prediction, explanations",
+    "test_id, model, dataset, test_tensor, batch_size, explanations",
     [
-        ("load_rand_test_predictions", "load_rand_tensor_explanations"),
+        (
+            "mnist",
+            "load_mnist_model",
+            "load_mnist_dataset",
+            "load_mnist_test_samples_1",
+            8,
+            "load_mnist_explanations_1",
+        ),
     ],
 )
-def test_identical_class_metrics(test_prediction, explanations, request):
-    test_prediction = request.getfixturevalue(test_prediction)
-    explanations = request.getfixturevalue(explanations)
+def test_identical_class_metrics(test_id, model, dataset, test_tensor, batch_size, explanations, request):
+    model = request.getfixturevalue(model)
+    test_tensor = request.getfixturevalue(test_tensor)
+    dataset = request.getfixturevalue(dataset)
+    tda = request.getfixturevalue(explanations)
     metric = IdenticalClass(device="cpu")
-    score = metric(test_prediction, explanations)["score"]
-    assert score > 0
+    score = metric(model=model, train_dataset=dataset, test_dataset=test_tensor, explanations=tda)["score"]
+    # TODO: introduce a more meaningful test, where the score is not zero
+    assert score == 0
diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py
index e69de29b..7c4458ed 100644
--- a/tests/metrics/test_unnamed_metrics.py
+++ b/tests/metrics/test_unnamed_metrics.py
@@ -0,0 +1,20 @@
+import pytest
+
+from metrics.unnamed.top_k_overlap import TopKOverlap
+
+
+@pytest.mark.unnamed_metrics
+@pytest.mark.parametrize(
+    "test_id, model, dataset, top_k, batch_size, explanations",
+    [
+        ("mnist", "load_mnist_model", "load_mnist_dataset", 3, 8, "load_mnist_explanations_1"),
+    ],
+)
+def test_top_k_overlap_metrics(test_id, model, dataset, top_k, batch_size, explanations, request):
+    model = request.getfixturevalue(model)
+    dataset = request.getfixturevalue(dataset)
+    tda = request.getfixturevalue(explanations)
+    metric = TopKOverlap(device="cpu")
+    score = metric(model=model, train_dataset=dataset, top_k=top_k, explanations=tda, batch_size=batch_size)["score"]
+
+    assert score == 10
diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py
deleted file mode 100644
index 682fd2d7..00000000
--- a/tests/utils/conftest.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import pytest
-import torch
-
-
-@pytest.fixture()
-def load_dataset():
-    x = torch.stack([torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)])
-    y = torch.tensor([0, 1, 0]).long()
-    return torch.utils.data.TensorDataset(x, y)
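
Note on the batching contract this patch introduces: `TensorExplanations.__getitem__(i)` now returns the i-th batch of rows rather than the i-th row, and `__len__` counts batches. A minimal standalone sketch of that arithmetic follows; the 10 x 100 tensor and its shapes are illustrative assumptions, not the repository's test assets:

    import torch

    # Illustrative stand-in for an explanations tensor:
    # 10 test samples x 100 training samples.
    xpl = torch.rand(10, 100)
    batch_size = 8

    # __len__: number of batches, counting a trailing partial batch (ceil division).
    n_batches = (xpl.shape[0] + batch_size - 1) // batch_size  # -> 2

    # __getitem__(i): rows [i * batch_size, min((i + 1) * batch_size, n_rows)).
    batch_0 = xpl[0 * batch_size : min(1 * batch_size, xpl.shape[0])]  # shape [8, 100]
    batch_1 = xpl[1 * batch_size : min(2 * batch_size, xpl.shape[0])]  # shape [2, 100]

The last batch can be smaller than `batch_size`, which is why the per-batch assert in `IdenticalClass.__call__` compares `data.shape[0]` against `explanations[i].shape[0]` instead of assuming a fixed size.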