add test for caching mishaps
gumityolcu committed Aug 16, 2024
1 parent 09ad3a5 commit 0edf8ff
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions tests/explainers/wrappers/test_trak_wrapper.py
@@ -37,6 +37,39 @@ def test_trak_wrapper_explain_stateful(
assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"


@pytest.mark.explainers
@pytest.mark.parametrize(
"test_id, model, dataset, test_tensor, test_labels, method_kwargs",
[
(
"mnist",
"load_mnist_model",
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"},
),
],
)
# TODO: A good naming convention is "test_<function_name>..." or "test_<class_name>...".
# TODO: This could be named test_trak, since it tests the TRAK wrapper class.
# TODO: We could also make the explainer type (e.g. TRAK) a param; then it would be test_explainer or similar.
def test_trak_wrapper_explain_stateful_cache(
test_id, model, dataset, test_tensor, test_labels, method_kwargs, request, tmp_path
):
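# Checks that caching within the lifetime of a single TRAK instance does not serve stale explanations for new inputs.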
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
test_tensor = request.getfixturevalue(test_tensor)
test_labels = request.getfixturevalue(test_labels)

explainer = TRAK(model=model, cache_dir=tmp_path, train_dataset=dataset, **method_kwargs)

explanations = explainer.explain(test=test_tensor, targets=test_labels)
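# Replace the test inputs; a correctly behaving explainer must recompute, so the new explanations should differ.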
test_tensor = torch.ones_like(test_tensor)
explanations_2 = explainer.explain(test=test_tensor, targets=test_labels)
assert not torch.allclose(explanations, explanations_2), "Caching is problematic inside the lifetime of the wrapper"


@pytest.mark.explainers
@pytest.mark.parametrize(
"test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations",
@@ -72,6 +105,49 @@ def test_trak_wrapper_explain_functional(
assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"


@pytest.mark.explainers
@pytest.mark.parametrize(
"test_id, model, dataset, test_tensor, test_labels, method_kwargs",
[
(
"mnist",
"load_mnist_model",
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"},
),
],
)
def test_trak_wrapper_explain_functional_cache(
test_id, model, dataset, test_tensor, test_labels, method_kwargs, request, tmp_path
):
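# Checks caching across separate instantiations: two trak_explain calls sharing the same cache_dir and model_id.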
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
test_tensor = request.getfixturevalue(test_tensor)
test_labels = request.getfixturevalue(test_labels)
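# First call computes explanations and (presumably) writes the on-disk cache under tmp_path, keyed by model_id.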
explanations_first = trak_explain(
model=model,
cache_dir=str(tmp_path),
test_tensor=test_tensor,
train_dataset=dataset,
explanation_targets=test_labels,
device="cpu",
**method_kwargs,
)
test_tensor = torch.rand_like(test_tensor)
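# Second call uses fresh random inputs but the same cache_dir and model_id; the assertion below
# expects the cached explanations to be reproduced, i.e. cached results are reused across instantiations.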
explanations_second = trak_explain(
model=model,
cache_dir=str(tmp_path),
test_tensor=test_tensor,
train_dataset=dataset,
explanation_targets=test_labels,
device="cpu",
**method_kwargs,
)
assert torch.allclose(explanations_first, explanations_second), "Caching is problematic between different instantiations"


@pytest.mark.explainers
@pytest.mark.parametrize(
"test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations",
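For orientation, a minimal sketch of how the functional trak_explain might relate to the stateful TRAK wrapper, inferred from the calls in this diff (hypothetical; not the repository's actual implementation):

# Hypothetical sketch only: names and signature are inferred from the test calls above.
def trak_explain(model, cache_dir, test_tensor, train_dataset, explanation_targets, device="cpu", **kwargs):
    # A stateful TRAK explainer pointed at the shared cache_dir; repeated calls with the
    # same model_id could then reuse cached features across instantiations.
    # (device handling is omitted in this sketch)
    explainer = TRAK(model=model, cache_dir=cache_dir, train_dataset=train_dataset, **kwargs)
    return explainer.explain(test=test_tensor, targets=explanation_targets)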
