From 802a5c21f1e00f33f9472752c5fe66005877edee Mon Sep 17 00:00:00 2001 From: Dilyara Bareeva Date: Fri, 21 Jun 2024 11:07:07 +0200 Subject: [PATCH] using tmp_path fixture for tests --- .../wrappers/test_captum_influence.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index e88c353d..c0a1073b 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -1,5 +1,4 @@ import os -import shutil from collections import OrderedDict import pytest @@ -28,7 +27,7 @@ ], ) # TODO: I think a good naming convention is "test_..." or "test_...". -def test_self_influence(test_id, init_kwargs, request): +def test_self_influence(test_id, init_kwargs, tmp_path): # TODO: this should be a fixture. model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())])) @@ -38,11 +37,12 @@ def test_self_influence(test_id, init_kwargs, request): y = torch.randint(0, 10, (100,)) rand_dataset = TensorDataset(X, y) + # Using tmp_path pytest fixtures to create a temporary directory # TODO: One test should test one thing. This is test 1, .... self_influence_rank_functional = captum_similarity_self_influence( model=model, model_id="0", - cache_dir="temp_captum", + cache_dir=str(tmp_path), train_dataset=rand_dataset, init_kwargs=init_kwargs, device="cpu", @@ -53,7 +53,7 @@ def test_self_influence(test_id, init_kwargs, request): explainer_obj = CaptumSimilarity( model=model, model_id="1", - cache_dir="temp_captum2", + cache_dir=str(tmp_path), train_dataset=rand_dataset, device="cpu", **init_kwargs, @@ -63,13 +63,6 @@ def test_self_influence(test_id, init_kwargs, request): # TODO: here we then specifically test self_influence for CaptumSimilarity and should make it explicit in the name. self_influence_rank_stateful = explainer_obj.self_influence() - # TODO: we check "temp_captum2" but then remove os.path.join(os.getcwd(), "temp_captum2")? - # TODO: is there a reason to fear that the "temp_captum2" folder is not in os.getcwd()? - if os.path.isdir("temp_captum2"): - shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2")) - if os.path.isdir("temp_captum"): - shutil.rmtree(os.path.join(os.getcwd(), "temp_captum")) - # TODO: what if we pass a non-identity model? Then we don't expect torch.linalg.norm(X, dim=-1).argsort() # TODO: let's put expectations in the parametrisation of tests. We want to test different scenarios, # and not some super-specific case. This specific case definitely can be tested as well. @@ -96,7 +89,9 @@ def test_self_influence(test_id, init_kwargs, request): # TODO: I think a good naming convention is "test_..." or "test_...". # TODO: I would call it test_captum_similarity, because it is a test for the CaptumSimilarity class. # TODO: We could also make the explainer type (e.g. CaptumSimilarity) a param, then it would be test_explainer or something. -def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request): +def test_explain_stateful( + test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request, tmp_path +): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) test_tensor = request.getfixturevalue(test_tensor) @@ -106,7 +101,7 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te explainer = CaptumSimilarity( model=model, model_id="test_id", - cache_dir=os.path.join("./cache", "test_id"), + cache_dir=str(tmp_path), train_dataset=dataset, device="cpu", **method_kwargs,