From 5cc66894b0733e1500d2578292c1a98d6823926f Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Wed, 21 Aug 2024 22:22:48 +0200 Subject: [PATCH] change output structure --- .../metrics/localization/class_detection.py | 2 +- .../localization/mislabeling_detection.py | 2 +- .../randomization/model_randomization.py | 4 +- quanda/metrics/unnamed/dataset_cleaning.py | 2 +- quanda/metrics/unnamed/top_k_overlap.py | 2 +- .../wrappers/test_captum_influence.py | 63 +++++++++---------- tests/metrics/test_localization_metrics.py | 14 +++-- tests/metrics/test_randomization_metrics.py | 2 +- tests/metrics/test_unnamed_metrics.py | 18 +++--- .../localization/test_class_detection.py | 6 +- .../test_mislabeling_detection.py | 6 +- .../localization/test_subclass_detection.py | 10 +-- .../randomization/test_model_randomization.py | 6 +- .../unnamed/test_dataset_cleaning.py | 10 +-- .../unnamed/test_top_k_overlap.py | 6 +- 15 files changed, 83 insertions(+), 70 deletions(-) diff --git a/quanda/metrics/localization/class_detection.py b/quanda/metrics/localization/class_detection.py index 9db53e6a..fa6cce89 100644 --- a/quanda/metrics/localization/class_detection.py +++ b/quanda/metrics/localization/class_detection.py @@ -38,7 +38,7 @@ def compute(self): """ Used to aggregate current results and return a metric score. """ - return torch.cat(self.scores).mean() + return {"score": torch.cat(self.scores).mean().item()} def reset(self, *args, **kwargs): """ diff --git a/quanda/metrics/localization/mislabeling_detection.py b/quanda/metrics/localization/mislabeling_detection.py index 8b58b2b2..7e927e06 100644 --- a/quanda/metrics/localization/mislabeling_detection.py +++ b/quanda/metrics/localization/mislabeling_detection.py @@ -86,7 +86,7 @@ def compute(self, *args, **kwargs): normalized_curve = torch.cumsum(success_arr * 1.0, dim=0) / len(self.poisoned_indices) score = torch.trapezoid(normalized_curve) / len(self.poisoned_indices) return { - "success_arr": success_arr, "score": score.item(), + "success_arr": success_arr, "curve": normalized_curve / len(self.poisoned_indices), } diff --git a/quanda/metrics/randomization/model_randomization.py b/quanda/metrics/randomization/model_randomization.py index a72f3eef..9a069767 100644 --- a/quanda/metrics/randomization/model_randomization.py +++ b/quanda/metrics/randomization/model_randomization.py @@ -84,8 +84,8 @@ def explain_update( ) self.update(test_data=test_data, explanations=explanations, explanation_targets=explanation_targets) - def compute(self) -> float: - return torch.cat(self.results["scores"]).mean().item() + def compute(self): + return {"score": torch.cat(self.results["scores"]).mean().item()} def reset(self): self.results = {"scores": []} diff --git a/quanda/metrics/unnamed/dataset_cleaning.py b/quanda/metrics/unnamed/dataset_cleaning.py index 45c6cf68..003926f1 100644 --- a/quanda/metrics/unnamed/dataset_cleaning.py +++ b/quanda/metrics/unnamed/dataset_cleaning.py @@ -151,4 +151,4 @@ def compute(self, *args, **kwargs): clean_accuracy = class_accuracy(self.model, clean_dl, self.device) - return original_accuracy - clean_accuracy + return {"score": (original_accuracy - clean_accuracy)} diff --git a/quanda/metrics/unnamed/top_k_overlap.py b/quanda/metrics/unnamed/top_k_overlap.py index 8f7c7dab..a7299df1 100644 --- a/quanda/metrics/unnamed/top_k_overlap.py +++ b/quanda/metrics/unnamed/top_k_overlap.py @@ -27,7 +27,7 @@ def update( self.all_top_k_examples = torch.concat((self.all_top_k_examples, top_k_indices), dim=0) def compute(self, *args, **kwargs): - return len(torch.unique(self.all_top_k_examples)) + return {"score": len(torch.unique(self.all_top_k_examples))} def reset(self, *args, **kwargs): self.all_top_k_examples = torch.empty(0, self.top_k) diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index 2bf19b1c..f9e9d3d5 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -168,26 +168,24 @@ def test_captum_influence_explain_functional( {"batch_size": 1, "projection_dim": 10, "arnoldi_dim": 10}, ), ( - "mnist", - "load_mnist_model", - "load_mnist_dataset", - "load_mnist_test_samples_1", - "load_mnist_test_labels_1", - { - "batch_size": 1, - "projection_dim": 10, - "arnoldi_dim": 20, - "arnoldi_tol": 2e-1, - "hessian_reg": 2e-3, - "hessian_inverse_tol": 2e-4, - "projection_on_cpu": True, - }, + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + { + "batch_size": 1, + "projection_dim": 10, + "arnoldi_dim": 20, + "arnoldi_tol": 2e-1, + "hessian_reg": 2e-3, + "hessian_inverse_tol": 2e-4, + "projection_on_cpu": True, + }, ), ], ) -def test_captum_arnoldi( - test_id, model, dataset, test_tensor, test_labels, method_kwargs, request -): +def test_captum_arnoldi(test_id, model, dataset, test_tensor, test_labels, method_kwargs, request): model = request.getfixturevalue(model) dataset = request.getfixturevalue(dataset) test_tensor = request.getfixturevalue(test_tensor) @@ -236,21 +234,21 @@ def test_captum_arnoldi( }, ), ( - "mnist", - "load_mnist_model", - "load_mnist_dataset", - "load_mnist_test_samples_1", - "load_mnist_test_labels_1", - { - "batch_size": 1, - "seed": 42, - "projection_dim": 10, - "arnoldi_dim": 20, - "arnoldi_tol": 1e-1, - "hessian_reg": 1e-3, - "hessian_inverse_tol": 1e-4, - "projection_on_cpu": True, - }, + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + { + "batch_size": 1, + "seed": 42, + "projection_dim": 10, + "arnoldi_dim": 20, + "arnoldi_tol": 1e-1, + "hessian_reg": 1e-3, + "hessian_inverse_tol": 1e-4, + "projection_on_cpu": True, + }, ), ], ) @@ -263,7 +261,6 @@ def test_captum_arnoldi_explain_functional( test_labels = request.getfixturevalue(test_labels) hessian_dataset = torch.utils.data.Subset(dataset, [0, 1]) - explainer_captum = ArnoldiInfluenceFunction( model=model, train_dataset=dataset, diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py index 6e8e24ad..3606b405 100644 --- a/tests/metrics/test_localization_metrics.py +++ b/tests/metrics/test_localization_metrics.py @@ -1,3 +1,5 @@ +import math + import pytest from quanda.explainers import SumAggregator @@ -41,13 +43,13 @@ def test_identical_class_metrics( tda = request.getfixturevalue(explanations) metric = ClassDetectionMetric(model=model, train_dataset=dataset, device="cpu") metric.update(test_labels=test_labels, explanations=tda) - score = metric.compute() + score = metric.compute()["score"] # TODO: introduce a more meaningfull test, where the score is not zero # Note from Galip: # one idea could be: a random attributor should get approximately 1/( # of classes). # With a big test dataset, the probability of failing a truly random test # should diminish. - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.localization_metrics @@ -89,8 +91,8 @@ def test_identical_subclass_metrics( device="cpu", ) metric.update(test_subclasses=test_labels, explanations=tda) - score = metric.compute() - assert score == expected_score + score = metric.compute()["score"] + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.localization_metrics @@ -158,6 +160,6 @@ def test_poisoning_detection_metric( expl_kwargs=expl_kwargs, device="cpu", ) - score = metric.compute() + score = metric.compute()["score"] - assert score["score"] == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 3fe97766..8e762759 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -94,7 +94,7 @@ def test_randomization_metric( else: metric.update(test_data=test_data, explanations=tda, explanation_targets=test_labels) - out = metric.compute() + out = metric.compute()["score"] assert (out >= -1.0) & (out <= 1.0), "Test failed." diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index 9e6e8c25..ccf9d979 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -1,3 +1,5 @@ +import math + import pytest from quanda.explainers.wrappers.captum_influence import CaptumSimilarity @@ -37,8 +39,8 @@ def test_top_k_overlap_metrics( explanations = request.getfixturevalue(explanations) metric = TopKOverlapMetric(model=model, train_dataset=dataset, top_k=top_k, device="cpu") metric.update(explanations=explanations) - score = metric.compute() - assert score == expected_score + score = metric.compute()["score"] + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.unnamed_metrics @@ -135,9 +137,9 @@ def test_dataset_cleaning( device="cpu", ) - score = metric.compute() + score = metric.compute()["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.unnamed_metrics @@ -202,9 +204,9 @@ def test_dataset_cleaning_self_influence_based( device="cpu", ) - score = metric.compute() + score = metric.compute()["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.unnamed_metrics @@ -263,6 +265,6 @@ def test_dataset_cleaning_aggr_based( metric.update(explanations=explanations) - score = metric.compute() + score = metric.compute()["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) diff --git a/tests/toy_benchmarks/localization/test_class_detection.py b/tests/toy_benchmarks/localization/test_class_detection.py index 7686f7a6..ff07c207 100644 --- a/tests/toy_benchmarks/localization/test_class_detection.py +++ b/tests/toy_benchmarks/localization/test_class_detection.py @@ -1,3 +1,5 @@ +import math + import pytest from quanda.explainers.wrappers import CaptumSimilarity @@ -112,6 +114,6 @@ def test_class_detection( model_id="default_model_id", batch_size=batch_size, device="cpu", - ) + )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) diff --git a/tests/toy_benchmarks/localization/test_mislabeling_detection.py b/tests/toy_benchmarks/localization/test_mislabeling_detection.py index b2417955..4ab83337 100644 --- a/tests/toy_benchmarks/localization/test_mislabeling_detection.py +++ b/tests/toy_benchmarks/localization/test_mislabeling_detection.py @@ -1,3 +1,5 @@ +import math + import lightning as L import pytest @@ -166,7 +168,7 @@ def test_mislabeling_detection( device="cpu", )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.toy_benchmarks @@ -240,4 +242,4 @@ def test_mislabeling_detection_generate_from_pl_module( device="cpu", )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index d5a6a7cc..31c39805 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -1,3 +1,5 @@ +import math + import lightning as L import pytest @@ -152,9 +154,9 @@ def test_subclass_detection( use_predictions=use_pred, batch_size=batch_size, device="cpu", - ) + )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.toy_benchmarks @@ -228,6 +230,6 @@ def test_subclass_detection_generate_lightning_model( use_predictions=use_pred, batch_size=batch_size, device="cpu", - ) + )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) diff --git a/tests/toy_benchmarks/randomization/test_model_randomization.py b/tests/toy_benchmarks/randomization/test_model_randomization.py index 540c40db..5fecb74e 100644 --- a/tests/toy_benchmarks/randomization/test_model_randomization.py +++ b/tests/toy_benchmarks/randomization/test_model_randomization.py @@ -1,3 +1,5 @@ +import math + import pytest from quanda.explainers.wrappers import CaptumSimilarity @@ -112,6 +114,6 @@ def test_model_randomization( model_id="default_model_id", batch_size=batch_size, device="cpu", - ) + )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) diff --git a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py index b7fa1ce4..84e254dd 100644 --- a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py +++ b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py @@ -1,3 +1,5 @@ +import math + import lightning as L import pytest @@ -146,9 +148,9 @@ def test_dataset_cleaning( global_method=global_method, batch_size=batch_size, device="cpu", - ) + )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) @pytest.mark.toy_benchmarks @@ -218,6 +220,6 @@ def test_dataset_cleaning_generate_from_pl_module( global_method=global_method, batch_size=batch_size, device="cpu", - ) + )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001) diff --git a/tests/toy_benchmarks/unnamed/test_top_k_overlap.py b/tests/toy_benchmarks/unnamed/test_top_k_overlap.py index bd417a1c..e5d09cc5 100644 --- a/tests/toy_benchmarks/unnamed/test_top_k_overlap.py +++ b/tests/toy_benchmarks/unnamed/test_top_k_overlap.py @@ -1,3 +1,5 @@ +import math + import pytest from quanda.explainers.wrappers import CaptumSimilarity @@ -136,6 +138,6 @@ def test_class_detection( model_id="default_model_id", batch_size=batch_size, device="cpu", - ) + )["score"] - assert score == expected_score + assert math.isclose(score, expected_score, abs_tol=0.00001)