diff --git a/tests/explainers/wrappers/test_trak_wrapper.py b/tests/explainers/wrappers/test_trak_wrapper.py
index b3101cc4..bc3acc3f 100644
--- a/tests/explainers/wrappers/test_trak_wrapper.py
+++ b/tests/explainers/wrappers/test_trak_wrapper.py
@@ -37,6 +37,39 @@ def test_trak_wrapper_explain_stateful(
     assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"
 
 
+@pytest.mark.explainers
+@pytest.mark.parametrize(
+    "test_id, model, dataset, test_tensor, test_labels, method_kwargs",
+    [
+        (
+            "mnist",
+            "load_mnist_model",
+            "load_mnist_dataset",
+            "load_mnist_test_samples_1",
+            "load_mnist_test_labels_1",
+            {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"},
+        ),
+    ],
+)
+# TODO: I think a good naming convention is "test_..." or "test_...".
+# TODO: I would call it test_captum_similarity, because it is a test for the CaptumSimilarity class.
+# TODO: We could also make the explainer type (e.g. CaptumSimilarity) a param, then it would be test_explainer or something.
+def test_trak_wrapper_explain_stateful_cache(
+    test_id, model, dataset, test_tensor, test_labels, method_kwargs, request, tmp_path
+):
+    model = request.getfixturevalue(model)
+    dataset = request.getfixturevalue(dataset)
+    test_tensor = request.getfixturevalue(test_tensor)
+    test_labels = request.getfixturevalue(test_labels)
+
+    explainer = TRAK(model=model, cache_dir=tmp_path, train_dataset=dataset, **method_kwargs)
+
+    explanations = explainer.explain(test=test_tensor, targets=test_labels)
+    test_tensor = torch.ones_like(test_tensor)
+    explanations_2 = explainer.explain(test=test_tensor, targets=test_labels)
+    assert not torch.allclose(explanations, explanations_2), "Caching is problematic inside the lifetime of the wrapper"
+
+
 @pytest.mark.explainers
 @pytest.mark.parametrize(
     "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations",
@@ -72,6 +105,49 @@ def test_trak_wrapper_explain_functional(
     assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"
 
 
+@pytest.mark.explainers
+@pytest.mark.parametrize(
+    "test_id, model, dataset, test_tensor, test_labels, method_kwargs",
+    [
+        (
+            "mnist",
+            "load_mnist_model",
+            "load_mnist_dataset",
+            "load_mnist_test_samples_1",
+            "load_mnist_test_labels_1",
+            {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"},
+        ),
+    ],
+)
+def test_trak_wrapper_explain_functional_cache(
+    test_id, model, dataset, test_tensor, test_labels, method_kwargs, request, tmp_path
+):
+    model = request.getfixturevalue(model)
+    dataset = request.getfixturevalue(dataset)
+    test_tensor = request.getfixturevalue(test_tensor)
+    test_labels = request.getfixturevalue(test_labels)
+    explanations_first = trak_explain(
+        model=model,
+        cache_dir=str(tmp_path),
+        test_tensor=test_tensor,
+        train_dataset=dataset,
+        explanation_targets=test_labels,
+        device="cpu",
+        **method_kwargs,
+    )
+    test_tensor = torch.rand_like(test_tensor)
+    explanations_second = trak_explain(
+        model=model,
+        cache_dir=str(tmp_path),
+        test_tensor=test_tensor,
+        train_dataset=dataset,
+        explanation_targets=test_labels,
+        device="cpu",
+        **method_kwargs,
+    )
+    assert torch.allclose(explanations_first, explanations_second), "Caching is problematic between different instantiations"
+
+
 @pytest.mark.explainers
 @pytest.mark.parametrize(
     "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations",