add unit tests for top_k_overlap.py and identical_class.py
dilyabareeva committed May 13, 2024
1 parent e808089 commit c68451c
Showing 14 changed files with 119 additions and 70 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -7,6 +7,7 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files

- repo: local
hooks:
2 changes: 2 additions & 0 deletions pytest.ini
@@ -2,3 +2,5 @@
markers =
utils: utils files
explainers: explainers
localization_metrics: localization_metrics
unnamed_metrics: unnamed_metrics
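
Registering these marker names lets the new metric tests be selected on their own. A minimal sketch of driving that selection from Python, run from the repository root (the plain command-line form would be pytest -m localization_metrics); the tests/metrics path matches the test files added in this commit:

    import pytest

    # Run only the tests tagged with the newly registered markers; declaring them in
    # pytest.ini also avoids the unknown-mark warning for these names.
    exit_code = pytest.main(["-m", "localization_metrics or unnamed_metrics", "tests/metrics"])
    raise SystemExit(exit_code)
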
4 changes: 3 additions & 1 deletion src/explainers/explain_wrapper.py
@@ -41,5 +41,7 @@ def explain(
similarity_direction=sim_direction,
batch_size=batch_size,
)
topk_idx, topk_val = sim_influence.influence(test_tensor, len(train_dataset))[layer]
tda = torch.gather(topk_val, 1, topk_idx)

return sim_influence.influence(test_tensor, top_k)[layer]
return tda
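
With this change the wrapper returns a single training-data-attribution tensor with one row per test sample and one column per training example, which is how the metrics added in this commit index into it. A small self-contained torch sketch of that downstream use, with purely illustrative attribution values:

    import torch

    # Illustrative attribution tensor: 2 test samples x 3 training samples.
    tda = torch.tensor([[0.1, 0.9, 0.3],
                        [0.7, 0.2, 0.4]])

    top_one = tda.argmax(dim=1)             # most attributed training index per test sample
    top_two = torch.topk(tda, k=2).indices  # per-row top-k training indices, as TopKOverlap uses
    print(top_one.tolist())                 # [1, 0]
    print(top_two.tolist())                 # [[1, 2], [0, 2]]
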
2 changes: 1 addition & 1 deletion src/metrics/base.py
@@ -13,7 +13,7 @@ def __init__(self, device, *args, **kwargs):
@abstractmethod
def __call__(
self,
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations],
*args,
**kwargs,
):
"""
52 changes: 37 additions & 15 deletions src/metrics/localization/identical_class.py
@@ -1,3 +1,4 @@
import warnings
from typing import Optional, Union

import torch
@@ -16,49 +17,70 @@ def __init__(self, device, *args, **kwargs):

def __call__(
self,
test_predictions: torch.Tensor,
batch_size: int = 1,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
test_dataset: torch.utils.data.Dataset,
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
batch_size: Optional[int] = 8,
**kwargs,
):
"""
:param test_predictions:
:param explanations:
:param batch_size:
:param saved_explanations_batch_size:
:param kwargs:
:return:
"""

if isinstance(explanations, str):
explanations = EC.load(path=explanations, device=self.device)
if explanations.batch_size != batch_size:
warnings.warn(
"Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
"size will be used instead."
)
batch_size = explanations[0]
elif isinstance(explanations, torch.Tensor):
explanations = TensorExplanations(explanations)

# assert len(test_dataset) == len(explanations)
assert test_predictions.shape[0] == batch_size * len(
explanations
), f"Length of test predictions {test_predictions.shape[0]} and explanations {len(explanations)} do not match"
explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)

scores = []
for i in range(test_predictions.shape[0] // batch_size + 1):
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
for i, data in enumerate(test_dataloader):
if isinstance(data, tuple):
data = data[0]
assert data.shape[0] == explanations[i].shape[0], (
f"Batch size mismatch between explanations and input samples: "
f"{data.shape[0]} != {explanations[i].shape[0]} for batch {i}."
)
score = self._evaluate_instance(
test_labels=test_predictions[i * batch_size : i * batch_size + 1],
model=model,
train_dataset=train_dataset,
x_batch=data,
xpl=explanations[i],
)
scores.append(score)

return {"score": torch.tensor(scores).mean()}
return {"score": torch.cat(scores).mean()}

def _evaluate_instance(
self,
test_labels: torch.Tensor,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
x_batch: torch.Tensor,
xpl: torch.Tensor,
):
"""
Used to implement metric-specific logic.
"""

top_one_xpl_labels = xpl.argmax(dim=1)
top_one_xpl_indices = xpl.argmax(dim=1)
top_one_xpl_samples = torch.stack([train_dataset[i][0] for i in top_one_xpl_indices])

test_output = model(x_batch.to(self.device))
test_pred = test_output.argmax(dim=1)

top_one_xpl_output = model(top_one_xpl_samples.to(self.device))
top_one_xpl_pred = top_one_xpl_output.argmax(dim=1)

return (test_labels == top_one_xpl_labels) * 1.0
return (test_pred == top_one_xpl_pred) * 1.0
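
For reference, a minimal call sketch of the reworked metric, mirroring the updated test in tests/metrics/test_localization_metrics.py; model, train_dataset, test_dataset and tda are assumed to be the MNIST fixtures from tests/conftest.py rather than objects defined here:

    from metrics.localization.identical_class import IdenticalClass

    # model, train_dataset, test_dataset and tda are assumed fixtures (see tests/conftest.py);
    # explanations may also be a path to cached explanations or a TensorExplanations object.
    metric = IdenticalClass(device="cpu")
    result = metric(
        model=model,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        explanations=tda,
        batch_size=8,
    )
    # Mean 0/1 agreement between the model's prediction for each test sample and its
    # prediction for that sample's top-1 attributed training example.
    print(result["score"])
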
42 changes: 19 additions & 23 deletions src/metrics/unnamed/top_k_overlap.py
@@ -1,3 +1,4 @@
import warnings
from collections import Counter
from typing import Optional, Union

@@ -17,10 +18,11 @@ def __init__(self, device, *args, **kwargs):

def __call__(
self,
test_logits: torch.Tensor,
batch_size: int = 1,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
top_k: int = 1,
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
batch_size: Optional[int] = 8,
**kwargs,
):
"""
@@ -34,44 +36,38 @@ def __call__(

if isinstance(explanations, str):
explanations = EC.load(path=explanations, device=self.device)
if explanations.batch_size != batch_size:
warnings.warn(
"Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
"size will be used instead."
)
batch_size = explanations[0]
elif isinstance(explanations, torch.Tensor):
explanations = TensorExplanations(explanations)

# assert len(test_dataset) == len(explanations)
assert test_logits.shape[0] == batch_size * len(
explanations
), f"Length of test logits {test_logits.shape[0]} and explanations {len(explanations)} do not match"
explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)

all_top_k_examples = []
all_top_k_probs = []
for i in range(test_logits.shape[0] // batch_size + 1):
top_k_examples, top_k_probs = self._evaluate_instance(
test_logits=test_logits[i * batch_size : i * batch_size + 1],

for i in range(len(explanations)):
top_k_examples = self._evaluate_instance(
xpl=explanations[i],
top_k=top_k,
)
all_top_k_examples += top_k_examples
all_top_k_probs += top_k_probs

all_top_k_probs = torch.stack(all_top_k_probs)
# calculate the cardinality of the set of top-k examples
cardinality = len(set(all_top_k_examples))
# find the index of the first occurrence of the top-k examples
indices = [all_top_k_examples.index(ex) for ex in set(all_top_k_examples)]
# calculate the probability of the set of top-k examples
probability = all_top_k_probs[indices].mean()

return {"cardinality": cardinality, "probability": probability}
# TODO: calculate the probability of the set of top-k examples
return {"score": cardinality}

def _evaluate_instance(
self,
test_logits: torch.Tensor,
xpl: torch.Tensor,
top_k: int = 1,
):
"""
Used to implement metric-specific logic.
"""
top_k_examples = torch.topk(xpl.flatten(), top_k).indices
top_k_probs = torch.softmax(test_logits, dim=1)[top_k_examples]

return top_k_examples, top_k_probs
top_k_indices = torch.topk(xpl, top_k).indices
return top_k_indices
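
Conceptually, the metric now reports the cardinality of the union of every test sample's top-k training indices (the probability term is deferred to a TODO). A tiny self-contained sketch of that computation with made-up attribution values:

    import torch

    # Illustrative attributions: 3 test samples x 5 training samples.
    xpl = torch.tensor([[0.9, 0.1, 0.0, 0.3, 0.2],
                        [0.8, 0.0, 0.1, 0.4, 0.2],
                        [0.1, 0.7, 0.6, 0.0, 0.2]])

    top_k = 2
    top_k_indices = torch.topk(xpl, top_k).indices         # per-row top-k training indices
    cardinality = len(set(top_k_indices.flatten().tolist()))
    print(cardinality)  # union of {0, 3}, {0, 3}, {1, 2} -> 4
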
22 changes: 15 additions & 7 deletions src/utils/explanations.py
@@ -11,6 +11,14 @@ def __init__(
*args,
**kwargs,
):
"""
Explanations interface class. Used to define the interface for the Explanations classes.
Each explanation class implements __getitem__, __setitem__, and __len__ methods, whereby an "item" is an
explanation tensor batch.
:param args:
:param kwargs:
"""
pass

def __getitem__(self, index: Union[int, slice]) -> torch.Tensor:
@@ -27,11 +35,12 @@ class TensorExplanations(Explanations):
def __init__(
self,
tensor: torch.Tensor,
batch_size: Optional[int] = 8,
device: str = "cpu",
):
"""
Returns explanations from cache saved as tensors. __getitem__ and __setitem__ methods are used to access the
explanations on per-sample basis.
explanations on a batch basis.
:param dataset_id:
:param top_k:
@@ -40,18 +49,18 @@ def __init__(
super().__init__()
self.device = device
self.xpl = tensor.to(self.device)
self.batch_size = batch_size

# assert the number of explanation dimensions is 2 and insert extra dimension to emulate batching
assert len(self.xpl.shape) == 2, "Explanations object has more than 2 dimensions."
self.xpl = self.xpl.unsqueeze(1)

def __getitem__(self, idx: Union[int, slice]) -> torch.Tensor:
"""
:param idx:
:return:
"""
return self.xpl[idx]
return self.xpl[idx * self.batch_size : min((idx + 1) * self.batch_size, self.xpl.shape[0])]

def __setitem__(self, idx: Union[int, slice], val: Tuple[torch.Tensor, torch.Tensor]):
"""
@@ -61,18 +70,17 @@ def __setitem__(self, idx: Union[int, slice], val: Tuple[torch.Tensor, torch.Ten
:return:
"""

self.xpl[idx] = val
self.xpl[idx * self.batch_size : (idx + 1) * self.batch_size] = val
return val

def __len__(self) -> int:
return self.xpl.shape[0]
return int(self.xpl.shape[0] // self.batch_size) + 1


class BatchedCachedExplanations(Explanations):
def __init__(
self,
cache_dir: str = "./batch_wise_cached_explanations",
batch_size: Optional[int] = None,
device: str = "cpu",
):
"""
Expand All @@ -84,12 +92,12 @@ def __init__(
:param cache_dir:
"""
super().__init__()
self.batch_size = batch_size
self.cache_dir = cache_dir
self.device = device

self.av_filesearch = os.path.join(cache_dir, "*.pt")
self.files = glob.glob(self.av_filesearch)
self.batch_size = self[0].shape[0]

def __getitem__(self, idx: int) -> torch.Tensor:
"""
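
With the new batch_size argument, TensorExplanations serves item i as rows i*batch_size through min((i+1)*batch_size, N) of the wrapped tensor, and __len__ reports shape[0] // batch_size + 1. A small self-contained sketch of that slicing arithmetic, with arbitrary tensor contents:

    import torch

    xpl = torch.arange(20 * 4, dtype=torch.float32).reshape(20, 4)  # 20 explanation rows
    batch_size = 8

    def get_batch(idx: int) -> torch.Tensor:
        # Mirrors TensorExplanations.__getitem__ after this change: the idx-th batch of
        # rows, clipped at the end of the tensor.
        return xpl[idx * batch_size : min((idx + 1) * batch_size, xpl.shape[0])]

    print(get_batch(0).shape)               # torch.Size([8, 4])
    print(get_batch(2).shape)               # torch.Size([4, 4]) -- the final, partial batch
    print(xpl.shape[0] // batch_size + 1)   # 3, matching the new __len__
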
Binary file added tests/assets/mnist_test_suite_1/test_dataset.pt
Binary file not shown.
4 changes: 1 addition & 3 deletions tests/conftest.py
@@ -58,6 +58,4 @@ def load_mnist_test_samples_1():

@pytest.fixture()
def load_mnist_explanations_1():
rankings = torch.load(f"tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_ranking.pt")
tda = torch.load(f"tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt")
return rankings, tda
return torch.load(f"tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt")
5 changes: 2 additions & 3 deletions tests/explainers/test_explain_wrapper.py
@@ -25,8 +25,8 @@ def test_explain(test_id, model, dataset, explanations, test_tensor, method, met
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
test_tensor = request.getfixturevalue(test_tensor)
ranking_exp, tda_exp = request.getfixturevalue(explanations)
ranking, tda = explain(
tda_exp = request.getfixturevalue(explanations)
tda = explain(
model,
test_id,
os.path.join("./cache", "test_id"),
@@ -35,5 +35,4 @@
method,
**method_kwargs,
)
assert torch.allclose(ranking, ranking_exp), "Explanation rankings are not as expected"
assert torch.allclose(tda, tda_exp), "Training data attributions are not as expected"
26 changes: 18 additions & 8 deletions tests/metrics/test_localization_metrics.py
@@ -3,16 +3,26 @@
from metrics.localization.identical_class import IdenticalClass


@pytest.mark.utils
@pytest.mark.localization_metrics
@pytest.mark.parametrize(
"test_prediction, explanations",
"test_id, model, dataset, test_tensor, batch_size, explanations",
[
("load_rand_test_predictions", "load_rand_tensor_explanations"),
(
"mnist",
"load_mnist_model",
"load_mnist_dataset",
"load_mnist_test_samples_1",
8,
"load_mnist_explanations_1",
),
],
)
def test_identical_class_metrics(test_prediction, explanations, request):
test_prediction = request.getfixturevalue(test_prediction)
explanations = request.getfixturevalue(explanations)
def test_identical_class_metrics(test_id, model, dataset, test_tensor, batch_size, explanations, request):
model = request.getfixturevalue(model)
test_tensor = request.getfixturevalue(test_tensor)
dataset = request.getfixturevalue(dataset)
tda = request.getfixturevalue(explanations)
metric = IdenticalClass(device="cpu")
score = metric(test_prediction, explanations)["score"]
assert score > 0
score = metric(model=model, train_dataset=dataset, test_dataset=test_tensor, explanations=tda)["score"]
# TODO: introduce a more meaningful test, where the score is not zero
assert score == 0
20 changes: 20 additions & 0 deletions tests/metrics/test_unnamed_metrics.py
@@ -0,0 +1,20 @@
import pytest

from metrics.unnamed.top_k_overlap import TopKOverlap


@pytest.mark.unnamed_metrics
@pytest.mark.parametrize(
"test_id, model, dataset, top_k, batch_size, explanations",
[
("mnist", "load_mnist_model", "load_mnist_dataset", 3, 8, "load_mnist_explanations_1"),
],
)
def test_top_k_overlap_metrics(test_id, model, dataset, top_k, batch_size, explanations, request):
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
tda = request.getfixturevalue(explanations)
metric = TopKOverlap(device="cpu")
score = metric(model=model, train_dataset=dataset, top_k=top_k, explanations=tda, batch_size=batch_size)["score"]

assert score == 10
9 changes: 0 additions & 9 deletions tests/utils/conftest.py

This file was deleted.
