
Commit

fixing some tests
dilyabareeva committed Jun 20, 2024
1 parent 5cded75 commit ad2834c
Showing 17 changed files with 206 additions and 239 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -5,7 +5,7 @@ SHELL = /bin/bash
 .PHONY: style
 style:
 	black .
-	flake8 .
+	flake8 . --pytest-parametrize-names-type=csv
 	python -m isort .
 	rm -f .coverage
 	rm -f .coverage.*
1 change: 1 addition & 0 deletions pyproject.toml
@@ -56,6 +56,7 @@ dev = [ # Install with pip install .[dev] or pip install -e '.[dev]' in zsh
"coverage>=7.2.3",
"flake8>=6.0.0",
"pytest<=7.4.4",
"flake8-pytest-style>=1.3.2",
"pytest-cov>=4.0.0",
"pytest-mock==3.10.0",
"pre-commit>=3.2.0",
2 changes: 1 addition & 1 deletion src/metrics/base.py
@@ -40,7 +40,7 @@ def reset(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
-    def load_state_dict(self, state_dict: dict, *args, **kwargs):
+    def load_state_dict(self, state_dict: dict):
        """
        Used to load the metric state.
        """
50 changes: 0 additions & 50 deletions src/metrics/functional.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/metrics/localization/identical_class.py
@@ -8,11 +8,11 @@ def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
-        device,
+        device: str,
        *args,
        **kwargs,
    ):
-        super().__init__(model, train_dataset, device, *args, **kwargs)
+        super().__init__(model=model, train_dataset=train_dataset, device=device, *args, **kwargs)
        self.scores = []

    def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
@@ -27,8 +27,8 @@ def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
        top_one_xpl_indices = explanations.argmax(dim=1)
        top_one_xpl_targets = torch.stack([self.train_dataset[i][1] for i in top_one_xpl_indices])

-        score = (test_labels == top_one_xpl_targets) * 1.0
-        self.scores.append(score)
+        scores = (test_labels == top_one_xpl_targets) * 1.0
+        self.scores.append(scores)

    def compute(self):
        """
39 changes: 21 additions & 18 deletions src/metrics/randomization/model_randomization.py
@@ -64,16 +64,17 @@ def __init__(
            train_dataset=self.train_dataset,
        )

-        self.results = {"rank_correlations": []}
+        self.results = {"scores": []}

+        # TODO: create a validation utility function
        if isinstance(correlation_fn, str) and correlation_fn in correlation_functions:
-            self.correlation_measure = correlation_functions.get(correlation_fn)
+            self.corr_measure = correlation_functions.get(correlation_fn)
        elif callable(correlation_fn):
-            self.correlation_measure = correlation_fn
+            self.corr_measure = correlation_fn
        else:
            raise ValueError(
                f"Invalid correlation function: expected one of {list(correlation_functions.keys())} or"
-                f"a Callable, but got {self.correlation_measure}."
+                f"a Callable, but got {self.corr_measure}."
            )

    def update(
@@ -82,37 +83,39 @@ def update(
        explanations: torch.Tensor,
        explanation_targets: torch.Tensor,
    ):
-        device = "cuda" if torch.cuda.is_available() else "cpu"
        rand_explanations = self.explain_fn(
-            model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=device
+            model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device
        )
-        corrs = self.correlation_measure(explanations, rand_explanations)
-        self.results["rank_correlations"].append(corrs)
+        corrs = self.corr_measure(explanations, rand_explanations)
+        self.results["scores"].append(corrs)

    def compute(self):
-        return torch.cat(self.results["rank_correlations"]).mean()
+        return torch.cat(self.results["scores"]).mean()

    def reset(self):
-        self.results = {"rank_correlations": []}
+        self.results = {"scores": []}
        self.generator.manual_seed(self.seed)
        self.rand_model = self._randomize_model(self.model)

    def state_dict(self):
        state_dict = {
            "results_dict": self.results,
-            "random_model_state_dict": self.model.state_dict(),
-            "seed": self.seed,
-            "generator_state": self.generator.get_state(),
-            "explain_fn": self.explain_fn,
+            "rnd_model": self.model.state_dict(),
+            # Note to Galip: I suggest removing this, because those are explicitly passed
+            # as init arguments and this is an unexpected side effect if we overwrite them.
+            # Plus, we only ever use seed to randomize the model once.
+            # "seed": self.seed,
+            # "generator_state": self.generator.get_state(),
+            # "explain_fn": self.explain_fn,
        }
        return state_dict

    def load_state_dict(self, state_dict: dict):
        self.results = state_dict["results_dict"]
-        self.seed = state_dict["seed"]
-        self.explain_fn = state_dict["explain_fn"]
-        self.rand_model.load_state_dict(state_dict["random_model_state_dict"])
-        self.generator.set_state(state_dict["generator_state"])
+        self.rand_model.load_state_dict(state_dict["rnd_model"])
+        # self.seed = state_dict["seed"]
+        # self.explain_fn = state_dict["explain_fn"]
+        # self.generator.set_state(state_dict["generator_state"])

    def _randomize_model(self, model: torch.nn.Module):
        rand_model = copy.deepcopy(model)
2 changes: 1 addition & 1 deletion src/utils/datasets/utils.py
@@ -32,7 +32,7 @@ def load_datasets(dataset_name, dataset_type, **kwparams):
    elif dataset_type == "mark":
        ds = MarkDataset(ds, only_train=only_train)
        evalds = MarkDataset(evalds, only_train=only_train)
-    assert ds is not None and evalds is not None
+    # assert ds is not None and evalds is not None
    return ds, evalds


24 changes: 12 additions & 12 deletions tests/conftest.py
@@ -14,38 +14,38 @@
RANDOM_SEED = 42


-@pytest.fixture()
+@pytest.fixture
def load_dataset():
    x = torch.stack([torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)])
    y = torch.tensor([0, 1, 0]).long()
    return torch.utils.data.TensorDataset(x, y)


-@pytest.fixture()
+@pytest.fixture
def load_rand_tensor():
    return torch.rand(10, 10).float()


-@pytest.fixture()
+@pytest.fixture
def load_rand_test_predictions():
    return torch.randint(0, 10, (10000,))


-@pytest.fixture()
+@pytest.fixture
def load_mnist_model():
    """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models)."""
    model = LeNet()
    model.load_state_dict(torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle))
    return model


-@pytest.fixture()
+@pytest.fixture
def load_init_mnist_model():
    """Load a not trained LeNet classification model (architecture at quantus/helpers/models)."""
    return LeNet()


-@pytest.fixture()
+@pytest.fixture
def load_mnist_dataset():
    """Load a batch of MNIST digits: inputs and outputs to use for testing."""
    x_batch = (
@@ -58,7 +58,7 @@ def load_mnist_dataset():
    return dataset


-@pytest.fixture()
+@pytest.fixture
def load_mnist_dataloader():
    """Load a batch of MNIST digits: inputs and outputs to use for testing."""
    x_batch = (
@@ -72,26 +72,26 @@ def load_mnist_dataloader():
    return dataloader


-@pytest.fixture()
+@pytest.fixture
def load_mnist_test_samples_1():
    return torch.load("tests/assets/mnist_test_suite_1/test_dataset.pt")


-@pytest.fixture()
+@pytest.fixture
def load_mnist_test_labels_1():
    return torch.load("tests/assets/mnist_test_suite_1/test_labels.pt")


-@pytest.fixture()
+@pytest.fixture
def load_mnist_explanations_1():
    return torch.load("tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt")


-@pytest.fixture()
+@pytest.fixture
def torch_cross_entropy_loss_object():
    return torch.nn.CrossEntropyLoss()


-@pytest.fixture()
+@pytest.fixture
def torch_sgd_optimizer():
    return functools.partial(torch.optim.SGD, lr=0.01, momentum=0.9)
12 changes: 4 additions & 8 deletions tests/explainers/test_aggregators.py
@@ -6,17 +6,15 @@

@pytest.mark.aggregators
@pytest.mark.parametrize(
-    "test_id, dataset, explanations",
+    "test_id, explanations",
    [
        (
            "mnist",
-            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
-def test_sum_aggregator(test_id, dataset, explanations, request):
-    dataset = request.getfixturevalue(dataset)
+def test_sum_aggregator(test_id, explanations, request):
    explanations = request.getfixturevalue(explanations)
    aggregator = SumAggregator()
    aggregator.update(explanations)
@@ -26,17 +24,15 @@ def test_sum_aggregator(test_id, dataset, explanations, request):

@pytest.mark.aggregators
@pytest.mark.parametrize(
-    "test_id, dataset, explanations",
+    "test_id, explanations",
    [
        (
            "mnist",
-            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
-def test_abs_aggregator(test_id, dataset, explanations, request):
-    dataset = request.getfixturevalue(dataset)
+def test_abs_aggregator(test_id, explanations, request):
    explanations = request.getfixturevalue(explanations)
    aggregator = AbsSumAggregator()
    aggregator.update(explanations)
10 changes: 6 additions & 4 deletions tests/explainers/test_base_explainer.py
@@ -10,19 +10,21 @@

@pytest.mark.explainers
@pytest.mark.parametrize(
-    "test_id, model, dataset, method_kwargs",
+    "test_id, model, dataset, explanations, method_kwargs",
    [
        (
            "mnist",
            "load_mnist_model",
            "load_mnist_dataset",
+            "load_mnist_explanations_1",
            {"layers": "relu_4", "similarity_metric": cosine_similarity},
        ),
    ],
)
-def test_base_explain_self_influence(test_id, model, dataset, method_kwargs, mocker, request):
+def test_base_explain_self_influence(test_id, model, dataset, explanations, method_kwargs, mocker, request):
    model = request.getfixturevalue(model)
    dataset = request.getfixturevalue(dataset)
+    explanations = request.getfixturevalue(explanations)

    BaseExplainer.__abstractmethods__ = set()
    explainer = BaseExplainer(
@@ -34,9 +36,9 @@ def test_base_explain_self_influence(test_id, model, dataset, method_kwargs, mocker, request):
        **method_kwargs,
    )

-    # Patch the method
+    # Patch the method, because BaseExplainer has an abstract explain method.
    def mock_explain(test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None):
-        return torch.ones((test.shape[0], dataset.__len__()))
+        return explanations

    mocker.patch.object(explainer, "explain", wraps=mock_explain)


