diff --git a/pyproject.toml b/pyproject.toml
index 445834f3..9878a457 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,11 +13,13 @@ dependencies = [
     "lightning>=1.4.0",
     "torchmetrics>=1.4.0",
     "tqdm>=4.0.0",
+    "traker>=0.3.2",
 ]
 dynamic = ["version"]
 
 [tool.isort]
 profile = "black"
+extend_skip = ["__init__.py"]
 line_length = 79
 multi_line_output = 3
 include_trailing_comma = true
@@ -29,6 +31,10 @@ warn_unused_configs = true
 check_untyped_defs = true
 #ignore_errors = true # TODO: change this
 
+[[tool.mypy.overrides]]
+module = ["trak", "trak.projectors", "fast_jl"]
+ignore_missing_imports = true
+
 # Black formatting
 [tool.black]
 line-length = 127
diff --git a/quanda/explainers/__init__.py b/quanda/explainers/__init__.py
index c682a366..cf2439bc 100644
--- a/quanda/explainers/__init__.py
+++ b/quanda/explainers/__init__.py
@@ -1,3 +1,4 @@
+from quanda.explainers.base import BaseExplainer
 from quanda.explainers import utils, wrappers
 from quanda.explainers.aggregators import (
     AbsSumAggregator,
@@ -5,7 +6,6 @@
     SumAggregator,
     aggr_types,
 )
-from quanda.explainers.base import BaseExplainer
 from quanda.explainers.functional import ExplainFunc, ExplainFuncMini
 from quanda.explainers.random import RandomExplainer
 
diff --git a/quanda/explainers/wrappers/__init__.py b/quanda/explainers/wrappers/__init__.py
index a69d04f4..a22b4a78 100644
--- a/quanda/explainers/wrappers/__init__.py
+++ b/quanda/explainers/wrappers/__init__.py
@@ -11,6 +11,8 @@
     captum_tracincp_self_influence,
 )
 
+from quanda.explainers.wrappers.trak_wrapper import TRAK, trak_explain, trak_self_influence
+
 __all__ = [
     "CaptumInfluence",
     "CaptumSimilarity",
@@ -22,4 +24,7 @@
     "CaptumTracInCP",
     "captum_tracincp_explain",
     "captum_tracincp_self_influence",
+    "TRAK",
+    "trak_explain",
+    "trak_self_influence",
 ]
diff --git a/quanda/explainers/wrappers/captum_influence.py b/quanda/explainers/wrappers/captum_influence.py
index e1bdda55..7c52d486 100644
--- a/quanda/explainers/wrappers/captum_influence.py
+++ b/quanda/explainers/wrappers/captum_influence.py
@@ -11,7 +11,7 @@
     ArnoldiInfluenceFunction,
 )
 
-from quanda.explainers.base import BaseExplainer
+from quanda.explainers import BaseExplainer
 from quanda.explainers.utils import (
     explain_fn_from_explainer,
     self_influence_fn_from_explainer,
diff --git a/quanda/explainers/wrappers/trak_wrapper.py b/quanda/explainers/wrappers/trak_wrapper.py
new file mode 100644
index 00000000..153e08f2
--- /dev/null
+++ b/quanda/explainers/wrappers/trak_wrapper.py
@@ -0,0 +1,164 @@
+import warnings
+from importlib.util import find_spec
+from typing import Any, Iterable, List, Literal, Optional, Sized, Union
+
+import torch
+from trak import TRAKer
+from trak.projectors import BasicProjector, CudaProjector, NoOpProjector
+
+from quanda.explainers import BaseExplainer
+from quanda.explainers.utils import (
+    explain_fn_from_explainer,
+    self_influence_fn_from_explainer,
+)
+
+TRAKProjectorLiteral = Literal["cuda", "noop", "basic"]
+TRAKProjectionTypeLiteral = Literal["rademacher", "normal"]
+
+
+class TRAK(BaseExplainer):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        train_dataset: torch.utils.data.Dataset,
+        model_id: str,
+        cache_dir: str,
+        projector: TRAKProjectorLiteral,
+        device: Union[str, torch.device] = "cpu",
+        proj_dim: int = 128,
+        proj_type: TRAKProjectionTypeLiteral = "normal",
+        seed: int = 42,
+        batch_size: int = 32,
+        params_ldr: Optional[Iterable] = None,
+    ):
+        super(TRAK, self).__init__(
+            model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device
+        )
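+        # Everything below runs once at construction time: the projector is built and the training set is featurized.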
+        self.dataset = train_dataset
+        self.proj_dim = proj_dim
+        self.batch_size = batch_size
+        self.cache_dir = cache_dir if cache_dir is not None else f"./trak_{model_id}_cache"
+
+        num_params_for_grad = 0
+        params_iter = params_ldr if params_ldr is not None else self.model.parameters()
+        for p in list(params_iter):
+            num_params_for_grad += p.numel()
+        # Check if traker was installed with the ["cuda"] option
+        if projector == "cuda":
+            if find_spec("fast_jl"):
+                projector = "cuda"
+            else:
+                warnings.warn("Could not find the CUDA installation of TRAK (fast_jl). Defaulting to BasicProjector.")
+                projector = "basic"
+
+        projector_cls = {
+            "cuda": CudaProjector,
+            "basic": BasicProjector,
+            "noop": NoOpProjector,
+        }
+
+        projector_kwargs = {
+            "grad_dim": num_params_for_grad,
+            "proj_dim": proj_dim,
+            "proj_type": proj_type,
+            "seed": seed,
+            "device": device,
+        }
+        if projector == "cuda":
+            projector_kwargs["max_batch_size"] = self.batch_size
+        projector_obj = projector_cls[projector](**projector_kwargs)
+        self.traker = TRAKer(
+            model=model,
+            task="image_classification",
+            train_set_size=self.dataset_length,
+            projector=projector_obj,
+            proj_dim=proj_dim,
+            projector_seed=seed,
+            save_dir=self.cache_dir,
+            device=device,
+            use_half_precision=False,
+        )
+
+        # Train the TRAK explainer: featurize the training data
+        ld = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size)
+        self.traker.load_checkpoint(self.model.state_dict(), model_id=0)
+        for i, (x, y) in enumerate(ld):
+            batch = x.to(self.device), y.to(self.device)
+            self.traker.featurize(batch=batch, inds=torch.tensor([i * self.batch_size + j for j in range(x.shape[0])]))
+        self.traker.finalize_features()
+        if projector == "basic":
+            # finalize_features frees memory, so projector.proj_matrix needs to be reconstructed
+            self.traker.projector = projector_cls[projector](**projector_kwargs)
+
+    @property
+    def dataset_length(self) -> int:
+        """
+        torch.utils.data.Dataset is not guaranteed to implement __len__.
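+        When it does not, the length is computed by iterating a batch-size-1 DataLoader.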
+        :return: number of samples in the training dataset.
+        """
+        if isinstance(self.dataset, Sized):
+            return len(self.dataset)
+        dl = torch.utils.data.DataLoader(self.dataset, batch_size=1)
+        return len(dl)
+
+    def explain(self, test, targets):
+        test = test.to(self.device)
+        self.traker.start_scoring_checkpoint(
+            model_id=0, checkpoint=self.model.state_dict(), exp_name="test", num_targets=test.shape[0]
+        )
+        self.traker.score(batch=(test, targets), num_samples=test.shape[0])
+        explanations = torch.from_numpy(self.traker.finalize_scores(exp_name="test")).T.to(self.device)
+
+        # os.remove(os.path.join(self.cache_dir, "scores", "test.mmap"))
+        # os.removedirs(os.path.join(self.cache_dir, "scores"))
+
+        return explanations
+
+
+def trak_explain(
+    model: torch.nn.Module,
+    model_id: str,
+    cache_dir: Optional[str],
+    test_tensor: torch.Tensor,
+    train_dataset: torch.utils.data.Dataset,
+    device: Union[str, torch.device],
+    explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
+    **kwargs: Any,
+) -> torch.Tensor:
+    return explain_fn_from_explainer(
+        explainer_cls=TRAK,
+        model=model,
+        model_id=model_id,
+        cache_dir=cache_dir,
+        test_tensor=test_tensor,
+        targets=explanation_targets,
+        train_dataset=train_dataset,
+        device=device,
+        **kwargs,
+    )
+
+
+def trak_self_influence(
+    model: torch.nn.Module,
+    model_id: str,
+    cache_dir: Optional[str],
+    train_dataset: torch.utils.data.Dataset,
+    device: Union[str, torch.device],
+    batch_size: Optional[int] = 32,
+    **kwargs: Any,
+) -> torch.Tensor:
+    self_influence_kwargs = {
+        "batch_size": batch_size,
+    }
+    return self_influence_fn_from_explainer(
+        explainer_cls=TRAK,
+        model=model,
+        model_id=model_id,
+        cache_dir=cache_dir,
+        train_dataset=train_dataset,
+        device=device,
+        self_influence_kwargs=self_influence_kwargs,
+        **kwargs,
+    )
diff --git a/quanda/metrics/__init__.py b/quanda/metrics/__init__.py
index 40b3aa8e..8e5a4005 100644
--- a/quanda/metrics/__init__.py
+++ b/quanda/metrics/__init__.py
@@ -1,9 +1,9 @@
+from quanda.metrics.base import GlobalMetric, Metric
 from quanda.metrics import localization, randomization, unnamed
 from quanda.metrics.aggr_strategies import (
     GlobalAggrStrategy,
     GlobalSelfInfluenceStrategy,
 )
-from quanda.metrics.base import GlobalMetric, Metric
 
 __all__ = [
     "Metric",
diff --git a/quanda/metrics/localization/class_detection.py b/quanda/metrics/localization/class_detection.py
index 828b6778..a3234e80 100644
--- a/quanda/metrics/localization/class_detection.py
+++ b/quanda/metrics/localization/class_detection.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from quanda.metrics.base import Metric
+from quanda.metrics import Metric
 
 
 class ClassDetectionMetric(Metric):
diff --git a/quanda/metrics/localization/mislabeling_detection.py b/quanda/metrics/localization/mislabeling_detection.py
index f8055692..9266ed73 100644
--- a/quanda/metrics/localization/mislabeling_detection.py
+++ b/quanda/metrics/localization/mislabeling_detection.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from quanda.metrics.base import GlobalMetric
+from quanda.metrics import GlobalMetric
 
 
 class MislabelingDetectionMetric(GlobalMetric):
diff --git a/quanda/metrics/randomization/model_randomization.py b/quanda/metrics/randomization/model_randomization.py
index 2a584d83..97e5616f 100644
--- a/quanda/metrics/randomization/model_randomization.py
+++ b/quanda/metrics/randomization/model_randomization.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from quanda.metrics.base import Metric
+from quanda.metrics import Metric
 from quanda.utils.common import get_parent_module_from_name
 from quanda.utils.functions import CorrelationFnLiterals, correlation_functions
 
diff --git a/quanda/metrics/unnamed/dataset_cleaning.py b/quanda/metrics/unnamed/dataset_cleaning.py
index dfd10f5c..d1d95748 100644
--- a/quanda/metrics/unnamed/dataset_cleaning.py
+++ b/quanda/metrics/unnamed/dataset_cleaning.py
@@ -4,7 +4,7 @@
 import lightning as L
 import torch
 
-from quanda.metrics.base import GlobalMetric
+from quanda.metrics import GlobalMetric
 from quanda.utils.common import class_accuracy
 from quanda.utils.training import BaseTrainer
 
diff --git a/quanda/metrics/unnamed/top_k_overlap.py b/quanda/metrics/unnamed/top_k_overlap.py
index e9ef3e6f..b86353dc 100644
--- a/quanda/metrics/unnamed/top_k_overlap.py
+++ b/quanda/metrics/unnamed/top_k_overlap.py
@@ -1,6 +1,6 @@
 import torch
 
-from quanda.metrics.base import Metric
+from quanda.metrics import Metric
 
 
 class TopKOverlapMetric(Metric):
diff --git a/quanda/toy_benchmarks/__init__.py b/quanda/toy_benchmarks/__init__.py
index 7dc47e53..b275c546 100644
--- a/quanda/toy_benchmarks/__init__.py
+++ b/quanda/toy_benchmarks/__init__.py
@@ -1,4 +1,4 @@
-from quanda.toy_benchmarks import localization, randomization, unnamed
 from quanda.toy_benchmarks.base import ToyBenchmark
+from quanda.toy_benchmarks import localization, randomization, unnamed
 
 __all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"]
diff --git a/quanda/toy_benchmarks/localization/class_detection.py b/quanda/toy_benchmarks/localization/class_detection.py
index 93dc3f0f..3e2d843d 100644
--- a/quanda/toy_benchmarks/localization/class_detection.py
+++ b/quanda/toy_benchmarks/localization/class_detection.py
@@ -4,7 +4,7 @@
 from tqdm import tqdm
 
 from quanda.metrics.localization import ClassDetectionMetric
-from quanda.toy_benchmarks.base import ToyBenchmark
+from quanda.toy_benchmarks import ToyBenchmark
 
 
 class ClassDetection(ToyBenchmark):
diff --git a/quanda/toy_benchmarks/localization/mislabeling_detection.py b/quanda/toy_benchmarks/localization/mislabeling_detection.py
index 0f5313e3..e31f8638 100644
--- a/quanda/toy_benchmarks/localization/mislabeling_detection.py
+++ b/quanda/toy_benchmarks/localization/mislabeling_detection.py
@@ -8,7 +8,7 @@
 from quanda.metrics.localization.mislabeling_detection import (
     MislabelingDetectionMetric,
 )
-from quanda.toy_benchmarks.base import ToyBenchmark
+from quanda.toy_benchmarks import ToyBenchmark
 from quanda.utils.datasets.transformed.label_flipping import (
     LabelFlippingDataset,
 )
diff --git a/quanda/toy_benchmarks/localization/subclass_detection.py b/quanda/toy_benchmarks/localization/subclass_detection.py
index f3dbe4f7..a2ced0b9 100644
--- a/quanda/toy_benchmarks/localization/subclass_detection.py
+++ b/quanda/toy_benchmarks/localization/subclass_detection.py
@@ -6,7 +6,7 @@
 from tqdm import tqdm
 
 from quanda.metrics.localization.class_detection import ClassDetectionMetric
-from quanda.toy_benchmarks.base import ToyBenchmark
+from quanda.toy_benchmarks import ToyBenchmark
 from quanda.utils.datasets.transformed.label_grouping import (
     ClassToGroupLiterals,
     LabelGroupingDataset,
diff --git a/quanda/toy_benchmarks/randomization/model_randomization.py b/quanda/toy_benchmarks/randomization/model_randomization.py
index c6d810fb..589f224b 100644
--- a/quanda/toy_benchmarks/randomization/model_randomization.py
+++ b/quanda/toy_benchmarks/randomization/model_randomization.py
@@ -6,7 +6,7 @@
 from quanda.metrics.randomization.model_randomization import (
     ModelRandomizationMetric,
 )
-from quanda.toy_benchmarks.base import ToyBenchmark
+from quanda.toy_benchmarks import ToyBenchmark
 from quanda.utils.functions import CorrelationFnLiterals
 
 
diff --git a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py
index 99a17364..999a51f5 100644
--- a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py
+++ b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py
@@ -6,7 +6,7 @@
 from tqdm import tqdm
 
 from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric
-from quanda.toy_benchmarks.base import ToyBenchmark
+from quanda.toy_benchmarks import ToyBenchmark
 from quanda.utils.training.trainer import BaseTrainer
 
diff --git a/quanda/toy_benchmarks/unnamed/top_k_overlap.py b/quanda/toy_benchmarks/unnamed/top_k_overlap.py
index 4ac0f83f..bbd86a29 100644
--- a/quanda/toy_benchmarks/unnamed/top_k_overlap.py
+++ b/quanda/toy_benchmarks/unnamed/top_k_overlap.py
@@ -4,7 +4,7 @@
 from tqdm import tqdm
 
 from quanda.metrics.unnamed import TopKOverlapMetric
-from quanda.toy_benchmarks.base import ToyBenchmark
+from quanda.toy_benchmarks import ToyBenchmark
 
 
 class TopKOverlap(ToyBenchmark):
diff --git a/src/explainers/__init__.py b/src/explainers/__init__.py
new file mode 100644
index 00000000..4c81faf4
--- /dev/null
+++ b/src/explainers/__init__.py
@@ -0,0 +1,18 @@
+from quanda.explainers.base import BaseExplainer
+from quanda.explainers import utils, wrappers
+from quanda.explainers.functional import ExplainFunc, ExplainFuncMini
+from quanda.explainers.random import RandomExplainer
+from quanda.explainers.aggregators import AbsSumAggregator, BaseAggregator, SumAggregator
+
+
+__all__ = [
+    "BaseExplainer",
+    "RandomExplainer",
+    "ExplainFunc",
+    "ExplainFuncMini",
+    "utils",
+    "wrappers",
+    "BaseAggregator",
+    "SumAggregator",
+    "AbsSumAggregator",
+]
diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py
new file mode 100644
index 00000000..4d3c5123
--- /dev/null
+++ b/src/metrics/__init__.py
@@ -0,0 +1,14 @@
+from quanda.metrics.base import GlobalMetric, Metric
+from quanda.metrics import localization, randomization, unnamed
+from quanda.metrics.aggr_strategies import GlobalAggrStrategy, GlobalSelfInfluenceStrategy
+
+
+__all__ = [
+    "Metric",
+    "GlobalMetric",
+    "GlobalAggrStrategy",
+    "GlobalSelfInfluenceStrategy",
+    "randomization",
+    "localization",
+    "unnamed",
+]
diff --git a/tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt b/tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt
new file mode 100644
index 00000000..36fbfd92
Binary files /dev/null and b/tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt differ
diff --git a/tests/assets/mnist_test_suite_1/mnist_TRAK_tda_si.pt b/tests/assets/mnist_test_suite_1/mnist_TRAK_tda_si.pt
new file mode 100644
index 00000000..39f36ebc
Binary files /dev/null and b/tests/assets/mnist_test_suite_1/mnist_TRAK_tda_si.pt differ
diff --git a/tests/conftest.py b/tests/conftest.py
index 3c200ccd..635a80b6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -201,10 +201,20 @@ def load_mnist_test_labels_1():
 
 
 @pytest.fixture
-def load_mnist_explanations_1():
+def load_mnist_explanations_similarity_1():
     return torch.load("tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt")
 
 
+@pytest.fixture
+def load_mnist_explanations_trak_1():
+    return torch.load("tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt")
+
+
+@pytest.fixture
+def load_mnist_explanations_trak_si_1():
+    return torch.load("tests/assets/mnist_test_suite_1/mnist_TRAK_tda_si.pt")
+
+
 @pytest.fixture
 def load_mnist_dataset_explanations():
     return torch.rand((MINI_BATCH_SIZE, MINI_BATCH_SIZE))
diff --git a/tests/explainers/test_aggregators.py b/tests/explainers/test_aggregators.py
index b9b3dc67..e566d1b6 100644
--- a/tests/explainers/test_aggregators.py
+++ b/tests/explainers/test_aggregators.py
@@ -8,10 +8,10 @@
 @pytest.mark.parametrize(
     "test_id, explanations, aggregator, expected",
     [
-        ("mnist0", "load_mnist_explanations_1", AbsSumAggregator, {}),
-        ("mnist1", "load_mnist_explanations_1", SumAggregator, {}),
-        ("mnist2", "load_mnist_explanations_1", SumAggregator, {"err_expl": ValueError}),
-        ("mnist3", "load_mnist_explanations_1", SumAggregator, {"err_reset": ValueError}),
+        ("mnist0", "load_mnist_explanations_similarity_1", AbsSumAggregator, {}),
+        ("mnist1", "load_mnist_explanations_similarity_1", SumAggregator, {}),
+        ("mnist2", "load_mnist_explanations_similarity_1", SumAggregator, {"err_expl": ValueError}),
+        ("mnist3", "load_mnist_explanations_similarity_1", SumAggregator, {"err_reset": ValueError}),
     ],
 )
 def test_aggregator_update(test_id, explanations, aggregator, expected, request):
@@ -36,7 +36,10 @@ def test_aggregator_update(test_id, explanations, aggregator, expected, request)
 @pytest.mark.aggregators
 @pytest.mark.parametrize(
     "test_id, explanations, aggregator",
-    [("mnist", "load_mnist_explanations_1", AbsSumAggregator), ("mnist", "load_mnist_explanations_1", SumAggregator)],
+    [
+        ("mnist", "load_mnist_explanations_similarity_1", AbsSumAggregator),
+        ("mnist", "load_mnist_explanations_similarity_1", SumAggregator),
+    ],
 )
 def test_aggregator_reset(test_id, explanations, aggregator, request):
     explanations = request.getfixturevalue(explanations)
@@ -49,7 +52,10 @@ def test_aggregator_reset(test_id, explanations, aggregator, request):
 @pytest.mark.aggregators
 @pytest.mark.parametrize(
     "test_id, explanations, aggregator",
-    [("mnist", "load_mnist_explanations_1", AbsSumAggregator), ("mnist", "load_mnist_explanations_1", SumAggregator)],
+    [
+        ("mnist", "load_mnist_explanations_similarity_1", AbsSumAggregator),
+        ("mnist", "load_mnist_explanations_similarity_1", SumAggregator),
+    ],
 )
 def test_aggregator_save(test_id, explanations, aggregator, request):
     explanations = request.getfixturevalue(explanations)
@@ -62,7 +68,10 @@ def test_aggregator_save(test_id, explanations, aggregator, request):
 @pytest.mark.aggregators
 @pytest.mark.parametrize(
     "test_id, explanations, aggregator",
-    [("mnist", "load_mnist_explanations_1", AbsSumAggregator), ("mnist", "load_mnist_explanations_1", SumAggregator)],
+    [
+        ("mnist", "load_mnist_explanations_similarity_1", AbsSumAggregator),
+        ("mnist", "load_mnist_explanations_similarity_1", SumAggregator),
+    ],
 )
 def test_aggregator_load(test_id, explanations, aggregator, request):
     explanations = request.getfixturevalue(explanations)
diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py
index 79e70f53..085e8e76 100644
--- a/tests/explainers/wrappers/test_captum_influence.py
+++ b/tests/explainers/wrappers/test_captum_influence.py
@@ -88,7 +88,7 @@ def test_self_influence(test_id, init_kwargs, tmp_path):
             "mnist",
             "load_mnist_model",
             "load_mnist_dataset",
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "load_mnist_test_samples_1",
             "load_mnist_test_labels_1",
             {"layers": "relu_4", "similarity_metric": cosine_similarity},
@@ -131,7 +131,7 @@ def test_captum_influence_explain_stateful(
             "load_mnist_test_samples_1",
             "load_mnist_test_labels_1",
             {"layers": "relu_4", "similarity_metric": cosine_similarity},
-            "load_mnist_explanations_1",
"load_mnist_explanations_similarity_1", ), ], ) diff --git a/tests/explainers/wrappers/test_trak_wrapper.py b/tests/explainers/wrappers/test_trak_wrapper.py new file mode 100644 index 00000000..b3101cc4 --- /dev/null +++ b/tests/explainers/wrappers/test_trak_wrapper.py @@ -0,0 +1,103 @@ +import pytest +import torch + +from quanda.explainers.wrappers import TRAK, trak_explain, trak_self_influence + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_explanations_trak_1", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"}, + ), + ], +) +# TODO: I think a good naming convention is "test_..." or "test_...". +# TODO: I would call it test_captum_similarity, because it is a test for the CaptumSimilarity class. +# TODO: We could also make the explainer type (e.g. CaptumSimilarity) a param, then it would be test_explainer or something. +def test_trak_wrapper_explain_stateful( + test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request, tmp_path +): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + + explainer = TRAK(model=model, cache_dir=tmp_path, train_dataset=dataset, **method_kwargs) + + explanations = explainer.explain(test=test_tensor, targets=test_labels) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"}, + "load_mnist_explanations_trak_1", + ), + ], +) +def test_trak_wrapper_explain_functional( + test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request, tmp_path +): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + explanations = trak_explain( + model=model, + cache_dir=str(tmp_path), + test_tensor=test_tensor, + train_dataset=dataset, + explanation_targets=test_labels, + device="cpu", + **method_kwargs, + ) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"}, + "load_mnist_explanations_trak_si_1", + ), + ], +) +def test_trak_wrapper_explain_self_influence_functional( + test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request, tmp_path +): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + explanations_exp = 
+    explanations = trak_self_influence(
+        model=model,
+        cache_dir=str(tmp_path),
+        train_dataset=dataset,
+        device="cpu",
+        **method_kwargs,
+    )
+    assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"
diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py
index 6e8e24ad..24d0c8ab 100644
--- a/tests/metrics/test_localization_metrics.py
+++ b/tests/metrics/test_localization_metrics.py
@@ -20,7 +20,7 @@
             "load_mnist_dataset",
             "load_mnist_test_labels_1",
             8,
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             0.1,
         ),
     ],
@@ -61,7 +61,7 @@ def test_identical_class_metrics(
             "load_mnist_labels",
             "load_mnist_test_labels_1",
             8,
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             0.1,
         ),
     ],
@@ -101,7 +101,7 @@ def test_identical_subclass_metrics(
             "mnist",
             "load_mnist_model",
             "load_poisoned_mnist_dataset",
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "self-influence",
             {"layers": "fc_2", "similarity_metric": cosine_similarity, "model_id": "test", "cache_dir": "cache"},
             0.4921875,
@@ -110,7 +110,7 @@
             "mnist",
             "load_mnist_model",
             "load_poisoned_mnist_dataset",
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             SumAggregator,
             None,
             0.4921875,
@@ -119,7 +119,7 @@
             "mnist",
             "load_mnist_model",
             "load_poisoned_mnist_dataset",
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "sum_abs",
             None,
             0.4921875,
diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py
index 3fe97766..053fc98f 100644
--- a/tests/metrics/test_randomization_metrics.py
+++ b/tests/metrics/test_randomization_metrics.py
@@ -22,7 +22,7 @@
                 "layers": "fc_2",
                 "similarity_metric": cosine_similarity,
             },
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "load_mnist_test_labels_1",
             "spearman",
         ),
@@ -37,7 +37,7 @@
                 "layers": "fc_2",
                 "similarity_metric": cosine_similarity,
             },
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "load_mnist_test_labels_1",
             "kendall",
         ),
@@ -52,7 +52,7 @@
                 "layers": "fc_2",
                 "similarity_metric": cosine_similarity,
             },
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "load_mnist_test_labels_1",
             "spearman",
         ),
diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py
index 9e6e8c25..7b363f0b 100644
--- a/tests/metrics/test_unnamed_metrics.py
+++ b/tests/metrics/test_unnamed_metrics.py
@@ -17,7 +17,7 @@
             "load_mnist_dataset",
             3,
             8,
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             7,
         ),
     ],
@@ -54,7 +54,7 @@ def test_top_k_overlap_metrics(
             "torch_cross_entropy_loss_object",
             3,
             "load_mnist_dataset",
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "sum_abs",
             50,
             None,
@@ -69,7 +69,7 @@
             "torch_cross_entropy_loss_object",
             3,
             "load_mnist_dataset",
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             "self-influence",
             50,
             {"layers": "fc_2", "similarity_metric": cosine_similarity, "cache_dir": "cache", "model_id": "test"},
@@ -152,7 +152,7 @@ def test_dataset_cleaning(
             "torch_cross_entropy_loss_object",
             3,
             "load_mnist_dataset",
-            "load_mnist_explanations_1",
+            "load_mnist_explanations_similarity_1",
             50,
             {"layers": "fc_2", "similarity_metric": cosine_similarity, "cache_dir": "cache", "model_id": "test"},
"cache", "model_id": "test"}, 8, @@ -219,7 +219,7 @@ def test_dataset_cleaning_self_influence_based( "torch_cross_entropy_loss_object", 3, "load_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", 50, 0.0, ),