From de348f09b3e0e040465328ba36c425ac3f55cd0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Fri, 9 Aug 2024 19:21:24 +0200 Subject: [PATCH 1/8] TRAK wrapper initial code --- quanda/explainers/wrappers/trak.py | 73 ++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 quanda/explainers/wrappers/trak.py diff --git a/quanda/explainers/wrappers/trak.py b/quanda/explainers/wrappers/trak.py new file mode 100644 index 00000000..e500fc23 --- /dev/null +++ b/quanda/explainers/wrappers/trak.py @@ -0,0 +1,73 @@ +from trak import TRAKer +from trak.projectors import BasicProjector, CudaProjector, NoOpProjector +from trak.projectors import ProjectionType + +from typing import Literal, Optional, Union +import os +import torch + +from quanda.explainers import BaseExplainer + +TRAKProjectorLiteral=Literal["cuda", "noop", "basic"] +TRAKProjectionTypeLiteral=Literal["rademacher", "normal"] + +class TRAK(BaseExplainer): + def __init__( + self, + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + projector: TRAKProjectorLiteral="basic", + proj_dim: int=128, + proj_type: TRAKProjectionTypeLiteral="normal", + seed: int=42, + batch_size: int=32, + ): + super(TRAK, self).__init__(model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device) + self.dataset=train_dataset + self.batch_size=batch_size + proj_type=ProjectionType.normal if proj_type=="normal" else ProjectionType.rademacher + + number_of_params=0 + for p in list(self.model.sim_parameters()): + nn = 1 + for s in list(p.size()): + nn = nn * s + number_of_params += nn + + projector_cls = { + "cuda": CudaProjector, + "basic": BasicProjector, + "noop": NoOpProjector + } + + projector_kwargs={ + "grad_dim": number_of_params, + "proj_dim": proj_dim, + "proj_type": proj_type, + "seed": seed, + "device": device + } + projector=projector_cls[projector](**projector_kwargs) + self.traker = TRAKer(model=model, task='image_classification', train_set_size=len(train_dataset), + projector=projector, proj_dim=proj_dim, projector_seed=seed, save_dir=cache_dir) + + #Train the TRAK explainer: featurize the training data + ld=torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size) + self.traker.load_checkpoint(self.model.state_dict(),model_id=0) + for (i,(x,y)) in enumerate(iter(ld)): + batch=x.to(self.device), y.to(self.device) + self.traker.featurize(batch=batch,inds=torch.tensor([i*self.batch_size+j for j in range(self.batch_size)])) + self.traker.finalize_features() + + def explain(self, x, targets): + x=x.to(self.device) + self.traker.start_scoring_checkpoint(model_id=0, + checkpoint=self.model.state_dict(), + exp_name='test', + num_targets=x.shape[0]) + self.traker.score(batch=(x,targets), num_samples=x.shape[0]) + return torch.from_numpy(self.traker.finalize_scores(exp_name='test')).T.to(self.device) + From 67c09bd5d7a2a3b4589be5b76c54404dea6bdc6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Mon, 12 Aug 2024 20:32:21 +0200 Subject: [PATCH 2/8] fix imports --- pyproject.toml | 5 + quanda/explainers/__init__.py | 2 +- quanda/explainers/wrappers/__init__.py | 5 + .../explainers/wrappers/captum_influence.py | 2 +- quanda/explainers/wrappers/trak.py | 128 ++++++++++++------ quanda/metrics/__init__.py | 2 +- .../metrics/localization/class_detection.py | 2 +- .../localization/mislabeling_detection.py | 2 +- .../randomization/model_randomization.py | 2 +- quanda/metrics/unnamed/dataset_cleaning.py | 2 +- quanda/metrics/unnamed/top_k_overlap.py | 2 +- quanda/toy_benchmarks/__init__.py | 2 +- .../localization/class_detection.py | 2 +- .../localization/mislabeling_detection.py | 2 +- .../localization/subclass_detection.py | 2 +- .../randomization/model_randomization.py | 2 +- .../unnamed/dataset_cleaning.py | 2 +- .../toy_benchmarks/unnamed/top_k_overlap.py | 2 +- 18 files changed, 111 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 445834f3..bd802d92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dynamic = ["version"] [tool.isort] profile = "black" +extend_skip = ["__init__.py"] line_length = 79 multi_line_output = 3 include_trailing_comma = true @@ -29,6 +30,10 @@ warn_unused_configs = true check_untyped_defs = true #ignore_errors = true # TODO: change this +[[tool.mypy.overrides]] +module = ["trak", "trak.projectors", "fast_jl"] +ignore_missing_imports = true + # Black formatting [tool.black] line-length = 127 diff --git a/quanda/explainers/__init__.py b/quanda/explainers/__init__.py index c682a366..cf2439bc 100644 --- a/quanda/explainers/__init__.py +++ b/quanda/explainers/__init__.py @@ -1,3 +1,4 @@ +from quanda.explainers.base import BaseExplainer from quanda.explainers import utils, wrappers from quanda.explainers.aggregators import ( AbsSumAggregator, @@ -5,7 +6,6 @@ SumAggregator, aggr_types, ) -from quanda.explainers.base import BaseExplainer from quanda.explainers.functional import ExplainFunc, ExplainFuncMini from quanda.explainers.random import RandomExplainer diff --git a/quanda/explainers/wrappers/__init__.py b/quanda/explainers/wrappers/__init__.py index a69d04f4..ba508a03 100644 --- a/quanda/explainers/wrappers/__init__.py +++ b/quanda/explainers/wrappers/__init__.py @@ -11,6 +11,10 @@ captum_tracincp_self_influence, ) +from quanda.explainers.wrappers.trak import ( + TRAK, +) + __all__ = [ "CaptumInfluence", "CaptumSimilarity", @@ -22,4 +26,5 @@ "CaptumTracInCP", "captum_tracincp_explain", "captum_tracincp_self_influence", + "TRAK", ] diff --git a/quanda/explainers/wrappers/captum_influence.py b/quanda/explainers/wrappers/captum_influence.py index e1bdda55..7c52d486 100644 --- a/quanda/explainers/wrappers/captum_influence.py +++ b/quanda/explainers/wrappers/captum_influence.py @@ -11,7 +11,7 @@ ArnoldiInfluenceFunction, ) -from quanda.explainers.base import BaseExplainer +from quanda.explainers import BaseExplainer from quanda.explainers.utils import ( explain_fn_from_explainer, self_influence_fn_from_explainer, diff --git a/quanda/explainers/wrappers/trak.py b/quanda/explainers/wrappers/trak.py index e500fc23..18c4243a 100644 --- a/quanda/explainers/wrappers/trak.py +++ b/quanda/explainers/wrappers/trak.py @@ -1,15 +1,26 @@ -from trak import TRAKer -from trak.projectors import BasicProjector, CudaProjector, NoOpProjector -from trak.projectors import ProjectionType +import warnings +from typing import Iterable, Literal, Optional, Sized, Union -from typing import Literal, Optional, Union -import os import torch +from trak import TRAKer +from trak.projectors import ( + BasicProjector, + CudaProjector, + NoOpProjector, + ProjectionType, +) from quanda.explainers import BaseExplainer -TRAKProjectorLiteral=Literal["cuda", "noop", "basic"] -TRAKProjectionTypeLiteral=Literal["rademacher", "normal"] +# from quanda.explainers.utils import ( +# explain_fn_from_explainer, +# self_influence_fn_from_explainer, +# ) + + +TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"] +TRAKProjectionTypeLiteral = Literal["rademacher", "normal"] + class TRAK(BaseExplainer): def __init__( @@ -19,55 +30,88 @@ def __init__( cache_dir: Optional[str], train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], - projector: TRAKProjectorLiteral="basic", - proj_dim: int=128, - proj_type: TRAKProjectionTypeLiteral="normal", - seed: int=42, - batch_size: int=32, + projector: TRAKProjectorLiteral = "check_cuda", + proj_dim: int = 128, + proj_type: TRAKProjectionTypeLiteral = "normal", + seed: int = 42, + batch_size: int = 32, + params_ldr: Optional[Iterable] = None, ): - super(TRAK, self).__init__(model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device) - self.dataset=train_dataset - self.batch_size=batch_size - proj_type=ProjectionType.normal if proj_type=="normal" else ProjectionType.rademacher - - number_of_params=0 + super(TRAK, self).__init__( + model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device + ) + self.dataset = train_dataset + self.proj_dim = proj_dim + self.batch_size = batch_size + proj_type = ProjectionType.normal if proj_type == "normal" else ProjectionType.rademacher + + num_params_for_grad = 0 for p in list(self.model.sim_parameters()): nn = 1 for s in list(p.size()): nn = nn * s - number_of_params += nn - + num_params_for_grad += nn + + # Check if traker was installer with the ["cuda"] option + try: + import fast_jl + + test_gradient = torch.ones(1, num_params_for_grad).cuda() + num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count + fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms) + projector = "cuda" + except (ImportError, RuntimeError, AttributeError) as e: + warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}") + warnings.warn("Defaulting to BasicProjector.") + projector = "basic" + projector_cls = { "cuda": CudaProjector, "basic": BasicProjector, - "noop": NoOpProjector + "noop": NoOpProjector, } - - projector_kwargs={ - "grad_dim": number_of_params, + + projector_kwargs = { + "grad_dim": num_params_for_grad, "proj_dim": proj_dim, "proj_type": proj_type, "seed": seed, - "device": device + "device": device, } - projector=projector_cls[projector](**projector_kwargs) - self.traker = TRAKer(model=model, task='image_classification', train_set_size=len(train_dataset), - projector=projector, proj_dim=proj_dim, projector_seed=seed, save_dir=cache_dir) + projector = projector_cls[projector](**projector_kwargs) + self.traker = TRAKer( + model=model, + task="image_classification", + train_set_size=self.dataset_length, + projector=projector, + proj_dim=proj_dim, + projector_seed=seed, + save_dir=cache_dir, + ) - #Train the TRAK explainer: featurize the training data - ld=torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size) - self.traker.load_checkpoint(self.model.state_dict(),model_id=0) - for (i,(x,y)) in enumerate(iter(ld)): - batch=x.to(self.device), y.to(self.device) - self.traker.featurize(batch=batch,inds=torch.tensor([i*self.batch_size+j for j in range(self.batch_size)])) + # Train the TRAK explainer: featurize the training data + ld = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size) + self.traker.load_checkpoint(self.model.state_dict(), model_id=0) + for i, (x, y) in enumerate(iter(ld)): + batch = x.to(self.device), y.to(self.device) + self.traker.featurize(batch=batch, inds=torch.tensor([i * self.batch_size + j for j in range(self.batch_size)])) self.traker.finalize_features() - def explain(self, x, targets): - x=x.to(self.device) - self.traker.start_scoring_checkpoint(model_id=0, - checkpoint=self.model.state_dict(), - exp_name='test', - num_targets=x.shape[0]) - self.traker.score(batch=(x,targets), num_samples=x.shape[0]) - return torch.from_numpy(self.traker.finalize_scores(exp_name='test')).T.to(self.device) + @property + def dataset_length(self) -> int: + """ + By default, the Dataset class does not always have a __len__ method. + :return: + """ + if isinstance(self.dataset, Sized): + return len(self.dataset) + dl = torch.utils.data.DataLoader(self.dataset, batch_size=1) + return len(dl) + def explain(self, x, targets): + x = x.to(self.device) + self.traker.start_scoring_checkpoint( + model_id=0, checkpoint=self.model.state_dict(), exp_name="test", num_targets=x.shape[0] + ) + self.traker.score(batch=(x, targets), num_samples=x.shape[0]) + return torch.from_numpy(self.traker.finalize_scores(exp_name="test")).T.to(self.device) diff --git a/quanda/metrics/__init__.py b/quanda/metrics/__init__.py index 40b3aa8e..8e5a4005 100644 --- a/quanda/metrics/__init__.py +++ b/quanda/metrics/__init__.py @@ -1,9 +1,9 @@ +from quanda.metrics.base import GlobalMetric, Metric from quanda.metrics import localization, randomization, unnamed from quanda.metrics.aggr_strategies import ( GlobalAggrStrategy, GlobalSelfInfluenceStrategy, ) -from quanda.metrics.base import GlobalMetric, Metric __all__ = [ "Metric", diff --git a/quanda/metrics/localization/class_detection.py b/quanda/metrics/localization/class_detection.py index 828b6778..a3234e80 100644 --- a/quanda/metrics/localization/class_detection.py +++ b/quanda/metrics/localization/class_detection.py @@ -2,7 +2,7 @@ import torch -from quanda.metrics.base import Metric +from quanda.metrics import Metric class ClassDetectionMetric(Metric): diff --git a/quanda/metrics/localization/mislabeling_detection.py b/quanda/metrics/localization/mislabeling_detection.py index f8055692..9266ed73 100644 --- a/quanda/metrics/localization/mislabeling_detection.py +++ b/quanda/metrics/localization/mislabeling_detection.py @@ -2,7 +2,7 @@ import torch -from quanda.metrics.base import GlobalMetric +from quanda.metrics import GlobalMetric class MislabelingDetectionMetric(GlobalMetric): diff --git a/quanda/metrics/randomization/model_randomization.py b/quanda/metrics/randomization/model_randomization.py index 2a584d83..97e5616f 100644 --- a/quanda/metrics/randomization/model_randomization.py +++ b/quanda/metrics/randomization/model_randomization.py @@ -3,7 +3,7 @@ import torch -from quanda.metrics.base import Metric +from quanda.metrics import Metric from quanda.utils.common import get_parent_module_from_name from quanda.utils.functions import CorrelationFnLiterals, correlation_functions diff --git a/quanda/metrics/unnamed/dataset_cleaning.py b/quanda/metrics/unnamed/dataset_cleaning.py index dfd10f5c..d1d95748 100644 --- a/quanda/metrics/unnamed/dataset_cleaning.py +++ b/quanda/metrics/unnamed/dataset_cleaning.py @@ -4,7 +4,7 @@ import lightning as L import torch -from quanda.metrics.base import GlobalMetric +from quanda.metrics import GlobalMetric from quanda.utils.common import class_accuracy from quanda.utils.training import BaseTrainer diff --git a/quanda/metrics/unnamed/top_k_overlap.py b/quanda/metrics/unnamed/top_k_overlap.py index e9ef3e6f..b86353dc 100644 --- a/quanda/metrics/unnamed/top_k_overlap.py +++ b/quanda/metrics/unnamed/top_k_overlap.py @@ -1,6 +1,6 @@ import torch -from quanda.metrics.base import Metric +from quanda.metrics import Metric class TopKOverlapMetric(Metric): diff --git a/quanda/toy_benchmarks/__init__.py b/quanda/toy_benchmarks/__init__.py index 7dc47e53..b275c546 100644 --- a/quanda/toy_benchmarks/__init__.py +++ b/quanda/toy_benchmarks/__init__.py @@ -1,4 +1,4 @@ -from quanda.toy_benchmarks import localization, randomization, unnamed from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import localization, randomization, unnamed __all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"] diff --git a/quanda/toy_benchmarks/localization/class_detection.py b/quanda/toy_benchmarks/localization/class_detection.py index 93dc3f0f..3e2d843d 100644 --- a/quanda/toy_benchmarks/localization/class_detection.py +++ b/quanda/toy_benchmarks/localization/class_detection.py @@ -4,7 +4,7 @@ from tqdm import tqdm from quanda.metrics.localization import ClassDetectionMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark class ClassDetection(ToyBenchmark): diff --git a/quanda/toy_benchmarks/localization/mislabeling_detection.py b/quanda/toy_benchmarks/localization/mislabeling_detection.py index 0f5313e3..e31f8638 100644 --- a/quanda/toy_benchmarks/localization/mislabeling_detection.py +++ b/quanda/toy_benchmarks/localization/mislabeling_detection.py @@ -8,7 +8,7 @@ from quanda.metrics.localization.mislabeling_detection import ( MislabelingDetectionMetric, ) -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.datasets.transformed.label_flipping import ( LabelFlippingDataset, ) diff --git a/quanda/toy_benchmarks/localization/subclass_detection.py b/quanda/toy_benchmarks/localization/subclass_detection.py index f3dbe4f7..a2ced0b9 100644 --- a/quanda/toy_benchmarks/localization/subclass_detection.py +++ b/quanda/toy_benchmarks/localization/subclass_detection.py @@ -6,7 +6,7 @@ from tqdm import tqdm from quanda.metrics.localization.class_detection import ClassDetectionMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.datasets.transformed.label_grouping import ( ClassToGroupLiterals, LabelGroupingDataset, diff --git a/quanda/toy_benchmarks/randomization/model_randomization.py b/quanda/toy_benchmarks/randomization/model_randomization.py index c6d810fb..589f224b 100644 --- a/quanda/toy_benchmarks/randomization/model_randomization.py +++ b/quanda/toy_benchmarks/randomization/model_randomization.py @@ -6,7 +6,7 @@ from quanda.metrics.randomization.model_randomization import ( ModelRandomizationMetric, ) -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.functions import CorrelationFnLiterals diff --git a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py index 99a17364..999a51f5 100644 --- a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py @@ -6,7 +6,7 @@ from tqdm import tqdm from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.training.trainer import BaseTrainer diff --git a/quanda/toy_benchmarks/unnamed/top_k_overlap.py b/quanda/toy_benchmarks/unnamed/top_k_overlap.py index 4ac0f83f..bbd86a29 100644 --- a/quanda/toy_benchmarks/unnamed/top_k_overlap.py +++ b/quanda/toy_benchmarks/unnamed/top_k_overlap.py @@ -4,7 +4,7 @@ from tqdm import tqdm from quanda.metrics.unnamed import TopKOverlapMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark class TopKOverlap(ToyBenchmark): From 9dfd4b90f46fe439e96f0ea269fb021f3c1611eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 13 Aug 2024 17:31:37 +0200 Subject: [PATCH 3/8] add trak tests --- quanda/explainers/wrappers/__init__.py | 6 +- quanda/explainers/wrappers/trak.py | 117 ------------------ src/explainers/__init__.py | 18 +++ src/metrics/__init__.py | 14 +++ tests/conftest.py | 12 +- tests/explainers/test_aggregators.py | 23 ++-- .../wrappers/test_captum_influence.py | 4 +- tests/metrics/test_localization_metrics.py | 10 +- tests/metrics/test_randomization_metrics.py | 6 +- tests/metrics/test_unnamed_metrics.py | 10 +- 10 files changed, 77 insertions(+), 143 deletions(-) delete mode 100644 quanda/explainers/wrappers/trak.py create mode 100644 src/explainers/__init__.py create mode 100644 src/metrics/__init__.py diff --git a/quanda/explainers/wrappers/__init__.py b/quanda/explainers/wrappers/__init__.py index ba508a03..a22b4a78 100644 --- a/quanda/explainers/wrappers/__init__.py +++ b/quanda/explainers/wrappers/__init__.py @@ -11,9 +11,7 @@ captum_tracincp_self_influence, ) -from quanda.explainers.wrappers.trak import ( - TRAK, -) +from quanda.explainers.wrappers.trak_wrapper import TRAK, trak_explain, trak_self_influence __all__ = [ "CaptumInfluence", @@ -27,4 +25,6 @@ "captum_tracincp_explain", "captum_tracincp_self_influence", "TRAK", + "trak_explain", + "trak_self_influence", ] diff --git a/quanda/explainers/wrappers/trak.py b/quanda/explainers/wrappers/trak.py deleted file mode 100644 index 18c4243a..00000000 --- a/quanda/explainers/wrappers/trak.py +++ /dev/null @@ -1,117 +0,0 @@ -import warnings -from typing import Iterable, Literal, Optional, Sized, Union - -import torch -from trak import TRAKer -from trak.projectors import ( - BasicProjector, - CudaProjector, - NoOpProjector, - ProjectionType, -) - -from quanda.explainers import BaseExplainer - -# from quanda.explainers.utils import ( -# explain_fn_from_explainer, -# self_influence_fn_from_explainer, -# ) - - -TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"] -TRAKProjectionTypeLiteral = Literal["rademacher", "normal"] - - -class TRAK(BaseExplainer): - def __init__( - self, - model: torch.nn.Module, - model_id: str, - cache_dir: Optional[str], - train_dataset: torch.utils.data.Dataset, - device: Union[str, torch.device], - projector: TRAKProjectorLiteral = "check_cuda", - proj_dim: int = 128, - proj_type: TRAKProjectionTypeLiteral = "normal", - seed: int = 42, - batch_size: int = 32, - params_ldr: Optional[Iterable] = None, - ): - super(TRAK, self).__init__( - model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device - ) - self.dataset = train_dataset - self.proj_dim = proj_dim - self.batch_size = batch_size - proj_type = ProjectionType.normal if proj_type == "normal" else ProjectionType.rademacher - - num_params_for_grad = 0 - for p in list(self.model.sim_parameters()): - nn = 1 - for s in list(p.size()): - nn = nn * s - num_params_for_grad += nn - - # Check if traker was installer with the ["cuda"] option - try: - import fast_jl - - test_gradient = torch.ones(1, num_params_for_grad).cuda() - num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count - fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms) - projector = "cuda" - except (ImportError, RuntimeError, AttributeError) as e: - warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}") - warnings.warn("Defaulting to BasicProjector.") - projector = "basic" - - projector_cls = { - "cuda": CudaProjector, - "basic": BasicProjector, - "noop": NoOpProjector, - } - - projector_kwargs = { - "grad_dim": num_params_for_grad, - "proj_dim": proj_dim, - "proj_type": proj_type, - "seed": seed, - "device": device, - } - projector = projector_cls[projector](**projector_kwargs) - self.traker = TRAKer( - model=model, - task="image_classification", - train_set_size=self.dataset_length, - projector=projector, - proj_dim=proj_dim, - projector_seed=seed, - save_dir=cache_dir, - ) - - # Train the TRAK explainer: featurize the training data - ld = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size) - self.traker.load_checkpoint(self.model.state_dict(), model_id=0) - for i, (x, y) in enumerate(iter(ld)): - batch = x.to(self.device), y.to(self.device) - self.traker.featurize(batch=batch, inds=torch.tensor([i * self.batch_size + j for j in range(self.batch_size)])) - self.traker.finalize_features() - - @property - def dataset_length(self) -> int: - """ - By default, the Dataset class does not always have a __len__ method. - :return: - """ - if isinstance(self.dataset, Sized): - return len(self.dataset) - dl = torch.utils.data.DataLoader(self.dataset, batch_size=1) - return len(dl) - - def explain(self, x, targets): - x = x.to(self.device) - self.traker.start_scoring_checkpoint( - model_id=0, checkpoint=self.model.state_dict(), exp_name="test", num_targets=x.shape[0] - ) - self.traker.score(batch=(x, targets), num_samples=x.shape[0]) - return torch.from_numpy(self.traker.finalize_scores(exp_name="test")).T.to(self.device) diff --git a/src/explainers/__init__.py b/src/explainers/__init__.py new file mode 100644 index 00000000..4c81faf4 --- /dev/null +++ b/src/explainers/__init__.py @@ -0,0 +1,18 @@ +from quanda.explainers.base import BaseExplainer +from quanda.explainers import utils, wrappers +from quanda.explainers.functional import ExplainFunc, ExplainFuncMini +from quanda.explainers.random import RandomExplainer +from quanda.explainers.aggregators import BaseAggregator, SumAggregator + + +__all__ = [ + "BaseExplainer", + "RandomExplainer", + "ExplainFunc", + "ExplainFuncMini", + "utils", + "wrappers", + "BaseAggregator", + "SumAggregator", + "AbsSumAggretor", +] diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py new file mode 100644 index 00000000..4d3c5123 --- /dev/null +++ b/src/metrics/__init__.py @@ -0,0 +1,14 @@ +from quanda.metrics.base import GlobalMetric, Metric +from quanda.metrics import localization, randomization, unnamed +from quanda.metrics.aggr_strategies import GlobalAggrStrategy, GlobalSelfInfluenceStrategy + + +__all__ = [ + "Metric", + "GlobalMetric", + "GlobalAggrStrategy", + "GlobalSelfInfluenceStrategy", + "randomization", + "localization", + "unnamed", +] diff --git a/tests/conftest.py b/tests/conftest.py index 3c200ccd..635a80b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -201,10 +201,20 @@ def load_mnist_test_labels_1(): @pytest.fixture -def load_mnist_explanations_1(): +def load_mnist_explanations_similarity_1(): return torch.load("tests/assets/mnist_test_suite_1/mnist_SimilarityInfluence_tda.pt") +@pytest.fixture +def load_mnist_explanations_trak_1(): + return torch.load("tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt") + + +@pytest.fixture +def load_mnist_explanations_trak_si_1(): + return torch.load("tests/assets/mnist_test_suite_1/mnist_TRAK_tda_si.pt") + + @pytest.fixture def load_mnist_dataset_explanations(): return torch.rand((MINI_BATCH_SIZE, MINI_BATCH_SIZE)) diff --git a/tests/explainers/test_aggregators.py b/tests/explainers/test_aggregators.py index b9b3dc67..e566d1b6 100644 --- a/tests/explainers/test_aggregators.py +++ b/tests/explainers/test_aggregators.py @@ -8,10 +8,10 @@ @pytest.mark.parametrize( "test_id, explanations, aggregator, expected", [ - ("mnist0", "load_mnist_explanations_1", AbsSumAggregator, {}), - ("mnist1", "load_mnist_explanations_1", SumAggregator, {}), - ("mnist2", "load_mnist_explanations_1", SumAggregator, {"err_expl": ValueError}), - ("mnist3", "load_mnist_explanations_1", SumAggregator, {"err_reset": ValueError}), + ("mnist0", "load_mnist_explanations_similarity_1", AbsSumAggregator, {}), + ("mnist1", "load_mnist_explanations_similarity_1", SumAggregator, {}), + ("mnist2", "load_mnist_explanations_similarity_1", SumAggregator, {"err_expl": ValueError}), + ("mnist3", "load_mnist_explanations_similarity_1", SumAggregator, {"err_reset": ValueError}), ], ) def test_aggregator_update(test_id, explanations, aggregator, expected, request): @@ -36,7 +36,10 @@ def test_aggregator_update(test_id, explanations, aggregator, expected, request) @pytest.mark.aggregators @pytest.mark.parametrize( "test_id, explanations, aggregator", - [("mnist", "load_mnist_explanations_1", AbsSumAggregator), ("mnist", "load_mnist_explanations_1", SumAggregator)], + [ + ("mnist", "load_mnist_explanations_similarity_1", AbsSumAggregator), + ("mnist", "load_mnist_explanations_similarity_1", SumAggregator), + ], ) def test_aggregator_reset(test_id, explanations, aggregator, request): explanations = request.getfixturevalue(explanations) @@ -49,7 +52,10 @@ def test_aggregator_reset(test_id, explanations, aggregator, request): @pytest.mark.aggregators @pytest.mark.parametrize( "test_id, explanations, aggregator", - [("mnist", "load_mnist_explanations_1", AbsSumAggregator), ("mnist", "load_mnist_explanations_1", SumAggregator)], + [ + ("mnist", "load_mnist_explanations_similarity_1", AbsSumAggregator), + ("mnist", "load_mnist_explanations_similarity_1", SumAggregator), + ], ) def test_aggregator_save(test_id, explanations, aggregator, request): explanations = request.getfixturevalue(explanations) @@ -62,7 +68,10 @@ def test_aggregator_save(test_id, explanations, aggregator, request): @pytest.mark.aggregators @pytest.mark.parametrize( "test_id, explanations, aggregator", - [("mnist", "load_mnist_explanations_1", AbsSumAggregator), ("mnist", "load_mnist_explanations_1", SumAggregator)], + [ + ("mnist", "load_mnist_explanations_similarity_1", AbsSumAggregator), + ("mnist", "load_mnist_explanations_similarity_1", SumAggregator), + ], ) def test_aggregator_load(test_id, explanations, aggregator, request): explanations = request.getfixturevalue(explanations) diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index 79e70f53..085e8e76 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -88,7 +88,7 @@ def test_self_influence(test_id, init_kwargs, tmp_path): "mnist", "load_mnist_model", "load_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "load_mnist_test_samples_1", "load_mnist_test_labels_1", {"layers": "relu_4", "similarity_metric": cosine_similarity}, @@ -131,7 +131,7 @@ def test_captum_influence_explain_stateful( "load_mnist_test_samples_1", "load_mnist_test_labels_1", {"layers": "relu_4", "similarity_metric": cosine_similarity}, - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", ), ], ) diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py index 6e8e24ad..24d0c8ab 100644 --- a/tests/metrics/test_localization_metrics.py +++ b/tests/metrics/test_localization_metrics.py @@ -20,7 +20,7 @@ "load_mnist_dataset", "load_mnist_test_labels_1", 8, - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", 0.1, ), ], @@ -61,7 +61,7 @@ def test_identical_class_metrics( "load_mnist_labels", "load_mnist_test_labels_1", 8, - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", 0.1, ), ], @@ -101,7 +101,7 @@ def test_identical_subclass_metrics( "mnist", "load_mnist_model", "load_poisoned_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "self-influence", {"layers": "fc_2", "similarity_metric": cosine_similarity, "model_id": "test", "cache_dir": "cache"}, 0.4921875, @@ -110,7 +110,7 @@ def test_identical_subclass_metrics( "mnist", "load_mnist_model", "load_poisoned_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", SumAggregator, None, 0.4921875, @@ -119,7 +119,7 @@ def test_identical_subclass_metrics( "mnist", "load_mnist_model", "load_poisoned_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "sum_abs", None, 0.4921875, diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 3fe97766..053fc98f 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -22,7 +22,7 @@ "layers": "fc_2", "similarity_metric": cosine_similarity, }, - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "load_mnist_test_labels_1", "spearman", ), @@ -37,7 +37,7 @@ "layers": "fc_2", "similarity_metric": cosine_similarity, }, - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "load_mnist_test_labels_1", "kendall", ), @@ -52,7 +52,7 @@ "layers": "fc_2", "similarity_metric": cosine_similarity, }, - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "load_mnist_test_labels_1", "spearman", ), diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index 9e6e8c25..7b363f0b 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -17,7 +17,7 @@ "load_mnist_dataset", 3, 8, - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", 7, ), ], @@ -54,7 +54,7 @@ def test_top_k_overlap_metrics( "torch_cross_entropy_loss_object", 3, "load_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "sum_abs", 50, None, @@ -69,7 +69,7 @@ def test_top_k_overlap_metrics( "torch_cross_entropy_loss_object", 3, "load_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", "self-influence", 50, {"layers": "fc_2", "similarity_metric": cosine_similarity, "cache_dir": "cache", "model_id": "test"}, @@ -152,7 +152,7 @@ def test_dataset_cleaning( "torch_cross_entropy_loss_object", 3, "load_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", 50, {"layers": "fc_2", "similarity_metric": cosine_similarity, "cache_dir": "cache", "model_id": "test"}, 8, @@ -219,7 +219,7 @@ def test_dataset_cleaning_self_influence_based( "torch_cross_entropy_loss_object", 3, "load_mnist_dataset", - "load_mnist_explanations_1", + "load_mnist_explanations_similarity_1", 50, 0.0, ), From b5eecd76c0d10e590336087d1bff4515a4b4752a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 13 Aug 2024 17:45:35 +0200 Subject: [PATCH 4/8] add trak wrapper code --- quanda/explainers/wrappers/trak_wrapper.py | 171 +++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 quanda/explainers/wrappers/trak_wrapper.py diff --git a/quanda/explainers/wrappers/trak_wrapper.py b/quanda/explainers/wrappers/trak_wrapper.py new file mode 100644 index 00000000..f86191ec --- /dev/null +++ b/quanda/explainers/wrappers/trak_wrapper.py @@ -0,0 +1,171 @@ +import warnings +from typing import Any, Iterable, List, Literal, Optional, Sized, Union + +import torch +from trak import TRAKer +from trak.projectors import BasicProjector, CudaProjector, NoOpProjector + +from quanda.explainers import BaseExplainer +from quanda.explainers.utils import ( + explain_fn_from_explainer, + self_influence_fn_from_explainer, +) + +TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"] +TRAKProjectionTypeLiteral = Literal["rademacher", "normal"] + + +class TRAK(BaseExplainer): + def __init__( + self, + model: torch.nn.Module, + train_dataset: torch.utils.data.Dataset, + model_id: str, + cache_dir: Optional[str] = None, + device: Union[str, torch.device] = "cpu", + projector: TRAKProjectorLiteral = "check_cuda", + proj_dim: int = 128, + proj_type: TRAKProjectionTypeLiteral = "normal", + seed: int = 42, + batch_size: int = 32, + params_ldr: Optional[Iterable] = None, + ): + super(TRAK, self).__init__( + model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device + ) + self.dataset = train_dataset + self.proj_dim = proj_dim + self.batch_size = batch_size + self.cache_dir = cache_dir if cache_dir is not None else f"./trak_{model_id}_cache" + + num_params_for_grad = 0 + params_iter = params_ldr if params_ldr is not None else self.model.parameters() + for p in list(params_iter): + nn = 1 + for s in list(p.size()): + nn = nn * s + num_params_for_grad += nn + + # Check if traker was installer with the ["cuda"] option + if projector in ["cuda", "check_cuda"]: + try: + import fast_jl + + test_gradient = torch.ones(1, num_params_for_grad).cuda() + num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count + fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms) + projector = "cuda" + except (ImportError, RuntimeError, AttributeError) as e: + warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}") + warnings.warn("Defaulting to BasicProjector.") + projector = "basic" + + projector_cls = { + "cuda": CudaProjector, + "basic": BasicProjector, + "noop": NoOpProjector, + } + + projector_kwargs = { + "grad_dim": num_params_for_grad, + "proj_dim": proj_dim, + "proj_type": proj_type, + "seed": seed, + "device": device, + } + if projector == "cuda": + projector_kwargs["max_batch_size"] = self.batch_size + projector_obj = projector_cls[projector](**projector_kwargs) + self.traker = TRAKer( + model=model, + task="image_classification", + train_set_size=self.dataset_length, + projector=projector_obj, + proj_dim=proj_dim, + projector_seed=seed, + save_dir=self.cache_dir, + device=device, + use_half_precision=False, + ) + + # Train the TRAK explainer: featurize the training data + ld = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size) + self.traker.load_checkpoint(self.model.state_dict(), model_id=0) + for i, (x, y) in enumerate(iter(ld)): + batch = x.to(self.device), y.to(self.device) + self.traker.featurize(batch=batch, inds=torch.tensor([i * self.batch_size + j for j in range(x.shape[0])])) + self.traker.finalize_features() + if projector == "basic": + # finalize_features frees memory so projector.proj_matrix needs to be reconstructed + self.traker.projector = projector_cls[projector](**projector_kwargs) + + @property + def dataset_length(self) -> int: + """ + By default, the Dataset class does not always have a __len__ method. + :return: + """ + if isinstance(self.dataset, Sized): + return len(self.dataset) + dl = torch.utils.data.DataLoader(self.dataset, batch_size=1) + return len(dl) + + def explain(self, test, targets): + test = test.to(self.device) + self.traker.start_scoring_checkpoint( + model_id=0, checkpoint=self.model.state_dict(), exp_name="test", num_targets=test.shape[0] + ) + self.traker.score(batch=(test, targets), num_samples=test.shape[0]) + explanations = torch.from_numpy(self.traker.finalize_scores(exp_name="test")).T.to(self.device) + + # os.remove(os.path.join(self.cache_dir, "scores", "test.mmap")) + # os.removedirs(os.path.join(self.cache_dir, "scores")) + + return explanations + + +def trak_explain( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + test_tensor: torch.Tensor, + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + explanation_targets: Optional[Union[List[int], torch.Tensor]] = None, + **kwargs: Any, +) -> torch.Tensor: + return explain_fn_from_explainer( + explainer_cls=TRAK, + model=model, + model_id=model_id, + cache_dir=cache_dir, + test_tensor=test_tensor, + targets=explanation_targets, + train_dataset=train_dataset, + device=device, + **kwargs, + ) + + +def trak_self_influence( + model: torch.nn.Module, + model_id: str, + cache_dir: Optional[str], + train_dataset: torch.utils.data.Dataset, + device: Union[str, torch.device], + batch_size: Optional[int] = 32, + **kwargs: Any, +) -> torch.Tensor: + self_influence_kwargs = { + "batch_size": batch_size, + } + return self_influence_fn_from_explainer( + explainer_cls=TRAK, + model=model, + model_id=model_id, + cache_dir=cache_dir, + train_dataset=train_dataset, + device=device, + self_influence_kwargs=self_influence_kwargs, + **kwargs, + ) From 8867dc74e7d48d537a83d7c4ed4c4cc0d60438cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Tue, 13 Aug 2024 17:52:10 +0200 Subject: [PATCH 5/8] add traker prerequisite --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index bd802d92..9878a457 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "lightning>=1.4.0", "torchmetrics>=1.4.0", "tqdm>=4.0.0", + "traker>=0.3.2" ] dynamic = ["version"] From 357866cadc3f0bc1ca5fe12395c8f60bdb231fe8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Wed, 14 Aug 2024 23:59:59 +0200 Subject: [PATCH 6/8] add trak tests --- .../explainers/wrappers/test_trak_wrapper.py | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/explainers/wrappers/test_trak_wrapper.py diff --git a/tests/explainers/wrappers/test_trak_wrapper.py b/tests/explainers/wrappers/test_trak_wrapper.py new file mode 100644 index 00000000..c9abe64d --- /dev/null +++ b/tests/explainers/wrappers/test_trak_wrapper.py @@ -0,0 +1,103 @@ +import pytest +import torch + +from quanda.explainers.wrappers import TRAK, trak_explain, trak_self_influence + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_explanations_trak_1", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10}, + ), + ], +) +# TODO: I think a good naming convention is "test_..." or "test_...". +# TODO: I would call it test_captum_similarity, because it is a test for the CaptumSimilarity class. +# TODO: We could also make the explainer type (e.g. CaptumSimilarity) a param, then it would be test_explainer or something. +def test_trak_wrapper_explain_stateful( + test_id, model, dataset, explanations, test_tensor, test_labels, method_kwargs, request, tmp_path +): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + + explainer = TRAK(model=model, cache_dir=tmp_path, train_dataset=dataset, **method_kwargs) + + explanations = explainer.explain(test=test_tensor, targets=test_labels) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10}, + "load_mnist_explanations_trak_1", + ), + ], +) +def test_trak_wrapper_explain_functional( + test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request, tmp_path +): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + test_tensor = request.getfixturevalue(test_tensor) + test_labels = request.getfixturevalue(test_labels) + explanations_exp = request.getfixturevalue(explanations) + explanations = trak_explain( + model=model, + cache_dir=str(tmp_path), + test_tensor=test_tensor, + train_dataset=dataset, + explanation_targets=test_labels, + device="cpu", + **method_kwargs, + ) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" + + +@pytest.mark.explainers +@pytest.mark.parametrize( + "test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations", + [ + ( + "mnist", + "load_mnist_model", + "load_mnist_dataset", + "load_mnist_test_samples_1", + "load_mnist_test_labels_1", + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10}, + "load_mnist_explanations_trak_si_1", + ), + ], +) +def test_trak_wrapper_explain_self_influence_functional( + test_id, model, dataset, test_tensor, test_labels, method_kwargs, explanations, request, tmp_path +): + model = request.getfixturevalue(model) + dataset = request.getfixturevalue(dataset) + explanations_exp = request.getfixturevalue(explanations) + explanations = trak_self_influence( + model=model, + cache_dir=str(tmp_path), + train_dataset=dataset, + device="cpu", + **method_kwargs, + ) + assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected" From 25280fe8573272f94e6ca56e417fcb9c6cef6dde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 15 Aug 2024 00:19:31 +0200 Subject: [PATCH 7/8] add assets for trak tests --- .../assets/mnist_test_suite_1/mnist_TRAK_tda.pt | Bin 0 -> 1535 bytes .../mnist_test_suite_1/mnist_TRAK_tda_si.pt | Bin 0 -> 1230 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt create mode 100644 tests/assets/mnist_test_suite_1/mnist_TRAK_tda_si.pt diff --git a/tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt b/tests/assets/mnist_test_suite_1/mnist_TRAK_tda.pt new file mode 100644 index 0000000000000000000000000000000000000000..36fbfd92438104f4e6f318a57657202204e05423 GIT binary patch literal 1535 zcmWIWW@cev;NW1u0OAbX47quk#U=3}L5|+>B`JyeDTyVCdIi}zZcgkBQ4r9;lw6Wu zl$@a#Us{rxQ_K}#l$unUnUfMn*}rtgadH;RDwYP z-KR;FC8_x!lii$TU|xm#7?sM&EGfxJ1=-bV`&g0Ezb8~fRsjU~@5HFon^V*H`Q*Jzr5l+juzIb+K`nVSUs-faRp3IrO! zBq{+C17R>&wf*9%>bGCCit1;so^&K`^^g1zEw2fItG4?Hu70eubCr8R`|5c*iK`X5 zT33JnU#I0S&AYm8`jl0gIZoOI#aSy&QVUk=8$4bm(`UOn{_~qv7fTyfxw}ScFWlX^ z`t1H`t7V+{v^&!qR=@vQxKcXc)vBGUcUE7V?Z0~JTVBo3T?!iAJ9#uj6{J^ux4yB$ z%VFY*>L^)FP4O8TeQ()T7`02UFlG3zbF-LP@3Dpc+AC*P=^Rn~zw&+NzST#0ZPt`Z zUC|LgxOR<)gX>zDe~{<~BsQE0in(hlkJKE1i4TNv$DA004SLL#r4|(f6QP@v1vya% zjzeQEkP9x(OAloNx(S4FyFise8{GwZ;KWm$T9lcXlUbDr$`bLJDQ-^c#5%~p)ZEO% z)WX8l+}Og@($dVx+``BJ2n@|k%`D7}jm=EVO$-c;%|NcSwB1>t0dye<2Y54r91AZw zkaL-# XmJ#W)M7`}YFW?%uSX;lPL2Lr)%?Kp9eK*CgBJYd3yA{Z>!J_U*upuEvMn>*ULU0_L*Lvmi-T8T*_v6m7D^-L90j<~u641dYOqhGW zf6I1R;JHb-6L`#%51$gttfEKUUX*w(D}xGu`-fv%wFO3NcWq<~6I!=R6WVy?`w=17 zXJIr-#5-|-V;t;-J`<^Mxlh*^75nsz?AQ(A>^WC{!;!$9!YDVwodjw zwAMaezjFHXwG8v1Z0r5w;YB_MFILt>%^YP%I8J!3E%SW7G))AC(|5xMWsVMB%pf5t zhe4hOpWrz3NchZy%DQ1-ncMkB2OZP=4j2%*&mHNoNkS_s{w~i>qStYwFnbdTajch03)BNNqm)`(|(vaQi*2 b5Y(y;*9Ri}rUHXy$F9LaLpA=u{@LC?V0aAO literal 0 HcmV?d00001 From 09ad3a5afb133e0db9664438286add9a481a6932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Thu, 15 Aug 2024 17:36:10 +0200 Subject: [PATCH 8/8] changes after review --- quanda/explainers/wrappers/trak_wrapper.py | 27 +++++++------------ .../explainers/wrappers/test_trak_wrapper.py | 6 ++--- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/quanda/explainers/wrappers/trak_wrapper.py b/quanda/explainers/wrappers/trak_wrapper.py index f86191ec..153e08f2 100644 --- a/quanda/explainers/wrappers/trak_wrapper.py +++ b/quanda/explainers/wrappers/trak_wrapper.py @@ -1,4 +1,5 @@ import warnings +from importlib.util import find_spec from typing import Any, Iterable, List, Literal, Optional, Sized, Union import torch @@ -11,7 +12,7 @@ self_influence_fn_from_explainer, ) -TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"] +TRAKProjectorLiteral = Literal["cuda", "noop", "basic"] TRAKProjectionTypeLiteral = Literal["rademacher", "normal"] @@ -21,9 +22,9 @@ def __init__( model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, model_id: str, - cache_dir: Optional[str] = None, + cache_dir: str, + projector: TRAKProjectorLiteral, device: Union[str, torch.device] = "cpu", - projector: TRAKProjectorLiteral = "check_cuda", proj_dim: int = 128, proj_type: TRAKProjectionTypeLiteral = "normal", seed: int = 42, @@ -41,23 +42,13 @@ def __init__( num_params_for_grad = 0 params_iter = params_ldr if params_ldr is not None else self.model.parameters() for p in list(params_iter): - nn = 1 - for s in list(p.size()): - nn = nn * s - num_params_for_grad += nn - + num_params_for_grad = num_params_for_grad + p.numel() # Check if traker was installer with the ["cuda"] option - if projector in ["cuda", "check_cuda"]: - try: - import fast_jl - - test_gradient = torch.ones(1, num_params_for_grad).cuda() - num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count - fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms) + if projector == "cuda": + if find_spec("fast_jl"): projector = "cuda" - except (ImportError, RuntimeError, AttributeError) as e: - warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}") - warnings.warn("Defaulting to BasicProjector.") + else: + warnings.warn("Could not find cuda installation of TRAK. Defaulting to BasicProjector.") projector = "basic" projector_cls = { diff --git a/tests/explainers/wrappers/test_trak_wrapper.py b/tests/explainers/wrappers/test_trak_wrapper.py index c9abe64d..b3101cc4 100644 --- a/tests/explainers/wrappers/test_trak_wrapper.py +++ b/tests/explainers/wrappers/test_trak_wrapper.py @@ -15,7 +15,7 @@ "load_mnist_explanations_trak_1", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10}, + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"}, ), ], ) @@ -47,7 +47,7 @@ def test_trak_wrapper_explain_stateful( "load_mnist_dataset", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10}, + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"}, "load_mnist_explanations_trak_1", ), ], @@ -82,7 +82,7 @@ def test_trak_wrapper_explain_functional( "load_mnist_dataset", "load_mnist_test_samples_1", "load_mnist_test_labels_1", - {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10}, + {"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"}, "load_mnist_explanations_trak_si_1", ), ],