Skip to content

Commit

Permalink
using tmp_path fixture for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Jun 21, 2024
1 parent c526bf0 commit 802a5c2
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions tests/explainers/wrappers/test_captum_influence.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import shutil
from collections import OrderedDict

import pytest
Expand Down Expand Up @@ -28,7 +27,7 @@
],
)
# TODO: I think a good naming convention is "test_<function_name>..." or "test_<class_name>...".
def test_self_influence(test_id, init_kwargs, request):
def test_self_influence(test_id, init_kwargs, tmp_path):
# TODO: this should be a fixture.
model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())]))

Expand All @@ -38,11 +37,12 @@ def test_self_influence(test_id, init_kwargs, request):
y = torch.randint(0, 10, (100,))
rand_dataset = TensorDataset(X, y)

# Using tmp_path pytest fixtures to create a temporary directory
# TODO: One test should test one thing. This is test 1, ....
self_influence_rank_functional = captum_similarity_self_influence(
model=model,
model_id="0",
cache_dir="temp_captum",
cache_dir=str(tmp_path),
train_dataset=rand_dataset,
init_kwargs=init_kwargs,
device="cpu",
Expand All @@ -53,7 +53,7 @@ def test_self_influence(test_id, init_kwargs, request):
explainer_obj = CaptumSimilarity(
model=model,
model_id="1",
cache_dir="temp_captum2",
cache_dir=str(tmp_path),
train_dataset=rand_dataset,
device="cpu",
**init_kwargs,
Expand All @@ -63,13 +63,6 @@ def test_self_influence(test_id, init_kwargs, request):
# TODO: here we then specifically test self_influence for CaptumSimilarity and should make it explicit in the name.
self_influence_rank_stateful = explainer_obj.self_influence()

# TODO: we check "temp_captum2" but then remove os.path.join(os.getcwd(), "temp_captum2")?
# TODO: is there a reason to fear that the "temp_captum2" folder is not in os.getcwd()?
if os.path.isdir("temp_captum2"):
shutil.rmtree(os.path.join(os.getcwd(), "temp_captum2"))
if os.path.isdir("temp_captum"):
shutil.rmtree(os.path.join(os.getcwd(), "temp_captum"))

# TODO: what if we pass a non-identity model? Then we don't expect torch.linalg.norm(X, dim=-1).argsort()
# TODO: let's put expectations in the parametrisation of tests. We want to test different scenarios,
# and not some super-specific case. This specific case definitely can be tested as well.
Expand All @@ -96,7 +89,9 @@ def test_self_influence(test_id, init_kwargs, request):
# TODO: I think a good naming convention is "test_<function_name>..." or "test_<class_name>...".
# TODO: I would call it test_captum_similarity, because it is a test for the CaptumSimilarity class.
# TODO: We could also make the explainer type (e.g. CaptumSimilarity) a param, then it would be test_explainer or something.
def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request):
def test_explain_stateful(
test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request, tmp_path
):
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
test_tensor = request.getfixturevalue(test_tensor)
Expand All @@ -106,7 +101,7 @@ def test_explain_stateful(test_id, model, dataset, explanations, test_tensor, te
explainer = CaptumSimilarity(
model=model,
model_id="test_id",
cache_dir=os.path.join("./cache", "test_id"),
cache_dir=str(tmp_path),
train_dataset=dataset,
device="cpu",
**method_kwargs,
Expand Down

0 comments on commit 802a5c2

Please sign in to comment.