Skip to content

Commit

Permalink
change output structure
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Aug 21, 2024
1 parent 8e9cecc commit 5cc6689
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 70 deletions.
2 changes: 1 addition & 1 deletion quanda/metrics/localization/class_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def compute(self):
"""
Used to aggregate current results and return a metric score.
"""
return torch.cat(self.scores).mean()
return {"score": torch.cat(self.scores).mean().item()}

def reset(self, *args, **kwargs):
"""
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/localization/mislabeling_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def compute(self, *args, **kwargs):
normalized_curve = torch.cumsum(success_arr * 1.0, dim=0) / len(self.poisoned_indices)
score = torch.trapezoid(normalized_curve) / len(self.poisoned_indices)
return {
"success_arr": success_arr,
"score": score.item(),
"success_arr": success_arr,
"curve": normalized_curve / len(self.poisoned_indices),
}
4 changes: 2 additions & 2 deletions quanda/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def explain_update(
)
self.update(test_data=test_data, explanations=explanations, explanation_targets=explanation_targets)

def compute(self) -> float:
return torch.cat(self.results["scores"]).mean().item()
def compute(self):
return {"score": torch.cat(self.results["scores"]).mean().item()}

def reset(self):
self.results = {"scores": []}
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/unnamed/dataset_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@ def compute(self, *args, **kwargs):

clean_accuracy = class_accuracy(self.model, clean_dl, self.device)

return original_accuracy - clean_accuracy
return {"score": (original_accuracy - clean_accuracy)}
2 changes: 1 addition & 1 deletion quanda/metrics/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def update(
self.all_top_k_examples = torch.concat((self.all_top_k_examples, top_k_indices), dim=0)

def compute(self, *args, **kwargs):
return len(torch.unique(self.all_top_k_examples))
return {"score": len(torch.unique(self.all_top_k_examples))}

def reset(self, *args, **kwargs):
self.all_top_k_examples = torch.empty(0, self.top_k)
Expand Down
63 changes: 30 additions & 33 deletions tests/explainers/wrappers/test_captum_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,26 +168,24 @@ def test_captum_influence_explain_functional(
{"batch_size": 1, "projection_dim": 10, "arnoldi_dim": 10},
),
(
"mnist",
"load_mnist_model",
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{
"batch_size": 1,
"projection_dim": 10,
"arnoldi_dim": 20,
"arnoldi_tol": 2e-1,
"hessian_reg": 2e-3,
"hessian_inverse_tol": 2e-4,
"projection_on_cpu": True,
},
"mnist",
"load_mnist_model",
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{
"batch_size": 1,
"projection_dim": 10,
"arnoldi_dim": 20,
"arnoldi_tol": 2e-1,
"hessian_reg": 2e-3,
"hessian_inverse_tol": 2e-4,
"projection_on_cpu": True,
},
),
],
)
def test_captum_arnoldi(
test_id, model, dataset, test_tensor, test_labels, method_kwargs, request
):
def test_captum_arnoldi(test_id, model, dataset, test_tensor, test_labels, method_kwargs, request):
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
test_tensor = request.getfixturevalue(test_tensor)
Expand Down Expand Up @@ -236,21 +234,21 @@ def test_captum_arnoldi(
},
),
(
"mnist",
"load_mnist_model",
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{
"batch_size": 1,
"seed": 42,
"projection_dim": 10,
"arnoldi_dim": 20,
"arnoldi_tol": 1e-1,
"hessian_reg": 1e-3,
"hessian_inverse_tol": 1e-4,
"projection_on_cpu": True,
},
"mnist",
"load_mnist_model",
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{
"batch_size": 1,
"seed": 42,
"projection_dim": 10,
"arnoldi_dim": 20,
"arnoldi_tol": 1e-1,
"hessian_reg": 1e-3,
"hessian_inverse_tol": 1e-4,
"projection_on_cpu": True,
},
),
],
)
Expand All @@ -263,7 +261,6 @@ def test_captum_arnoldi_explain_functional(
test_labels = request.getfixturevalue(test_labels)
hessian_dataset = torch.utils.data.Subset(dataset, [0, 1])


explainer_captum = ArnoldiInfluenceFunction(
model=model,
train_dataset=dataset,
Expand Down
14 changes: 8 additions & 6 deletions tests/metrics/test_localization_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import pytest

from quanda.explainers import SumAggregator
Expand Down Expand Up @@ -41,13 +43,13 @@ def test_identical_class_metrics(
tda = request.getfixturevalue(explanations)
metric = ClassDetectionMetric(model=model, train_dataset=dataset, device="cpu")
metric.update(test_labels=test_labels, explanations=tda)
score = metric.compute()
score = metric.compute()["score"]
# TODO: introduce a more meaningfull test, where the score is not zero
# Note from Galip:
# one idea could be: a random attributor should get approximately 1/( # of classes).
# With a big test dataset, the probability of failing a truly random test
# should diminish.
assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.localization_metrics
Expand Down Expand Up @@ -89,8 +91,8 @@ def test_identical_subclass_metrics(
device="cpu",
)
metric.update(test_subclasses=test_labels, explanations=tda)
score = metric.compute()
assert score == expected_score
score = metric.compute()["score"]
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.localization_metrics
Expand Down Expand Up @@ -158,6 +160,6 @@ def test_poisoning_detection_metric(
expl_kwargs=expl_kwargs,
device="cpu",
)
score = metric.compute()
score = metric.compute()["score"]

assert score["score"] == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)
2 changes: 1 addition & 1 deletion tests/metrics/test_randomization_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_randomization_metric(
else:
metric.update(test_data=test_data, explanations=tda, explanation_targets=test_labels)

out = metric.compute()
out = metric.compute()["score"]
assert (out >= -1.0) & (out <= 1.0), "Test failed."


Expand Down
18 changes: 10 additions & 8 deletions tests/metrics/test_unnamed_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import pytest

from quanda.explainers.wrappers.captum_influence import CaptumSimilarity
Expand Down Expand Up @@ -37,8 +39,8 @@ def test_top_k_overlap_metrics(
explanations = request.getfixturevalue(explanations)
metric = TopKOverlapMetric(model=model, train_dataset=dataset, top_k=top_k, device="cpu")
metric.update(explanations=explanations)
score = metric.compute()
assert score == expected_score
score = metric.compute()["score"]
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.unnamed_metrics
Expand Down Expand Up @@ -135,9 +137,9 @@ def test_dataset_cleaning(
device="cpu",
)

score = metric.compute()
score = metric.compute()["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.unnamed_metrics
Expand Down Expand Up @@ -202,9 +204,9 @@ def test_dataset_cleaning_self_influence_based(
device="cpu",
)

score = metric.compute()
score = metric.compute()["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.unnamed_metrics
Expand Down Expand Up @@ -263,6 +265,6 @@ def test_dataset_cleaning_aggr_based(

metric.update(explanations=explanations)

score = metric.compute()
score = metric.compute()["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)
6 changes: 4 additions & 2 deletions tests/toy_benchmarks/localization/test_class_detection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import pytest

from quanda.explainers.wrappers import CaptumSimilarity
Expand Down Expand Up @@ -112,6 +114,6 @@ def test_class_detection(
model_id="default_model_id",
batch_size=batch_size,
device="cpu",
)
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import lightning as L
import pytest

Expand Down Expand Up @@ -166,7 +168,7 @@ def test_mislabeling_detection(
device="cpu",
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.toy_benchmarks
Expand Down Expand Up @@ -240,4 +242,4 @@ def test_mislabeling_detection_generate_from_pl_module(
device="cpu",
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)
10 changes: 6 additions & 4 deletions tests/toy_benchmarks/localization/test_subclass_detection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import lightning as L
import pytest

Expand Down Expand Up @@ -152,9 +154,9 @@ def test_subclass_detection(
use_predictions=use_pred,
batch_size=batch_size,
device="cpu",
)
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.toy_benchmarks
Expand Down Expand Up @@ -228,6 +230,6 @@ def test_subclass_detection_generate_lightning_model(
use_predictions=use_pred,
batch_size=batch_size,
device="cpu",
)
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import pytest

from quanda.explainers.wrappers import CaptumSimilarity
Expand Down Expand Up @@ -112,6 +114,6 @@ def test_model_randomization(
model_id="default_model_id",
batch_size=batch_size,
device="cpu",
)
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)
10 changes: 6 additions & 4 deletions tests/toy_benchmarks/unnamed/test_dataset_cleaning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import lightning as L
import pytest

Expand Down Expand Up @@ -146,9 +148,9 @@ def test_dataset_cleaning(
global_method=global_method,
batch_size=batch_size,
device="cpu",
)
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)


@pytest.mark.toy_benchmarks
Expand Down Expand Up @@ -218,6 +220,6 @@ def test_dataset_cleaning_generate_from_pl_module(
global_method=global_method,
batch_size=batch_size,
device="cpu",
)
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)
6 changes: 4 additions & 2 deletions tests/toy_benchmarks/unnamed/test_top_k_overlap.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import pytest

from quanda.explainers.wrappers import CaptumSimilarity
Expand Down Expand Up @@ -136,6 +138,6 @@ def test_class_detection(
model_id="default_model_id",
batch_size=batch_size,
device="cpu",
)
)["score"]

assert score == expected_score
assert math.isclose(score, expected_score, abs_tol=0.00001)

0 comments on commit 5cc6689

Please sign in to comment.